5b65779b15773252e8971c4f5809b1ad89c644cd
[firefly-linux-kernel-4.4.55.git] / crypto / algif_skcipher.c
1 /*
2  * algif_skcipher: User-space interface for skcipher algorithms
3  *
4  * This file provides the user-space API for symmetric key ciphers.
5  *
6  * Copyright (c) 2010 Herbert Xu <herbert@gondor.apana.org.au>
7  *
8  * This program is free software; you can redistribute it and/or modify it
9  * under the terms of the GNU General Public License as published by the Free
10  * Software Foundation; either version 2 of the License, or (at your option)
11  * any later version.
12  *
13  */
14
15 #include <crypto/scatterwalk.h>
16 #include <crypto/skcipher.h>
17 #include <crypto/if_alg.h>
18 #include <linux/init.h>
19 #include <linux/list.h>
20 #include <linux/kernel.h>
21 #include <linux/mm.h>
22 #include <linux/module.h>
23 #include <linux/net.h>
24 #include <net/sock.h>
25
26 struct skcipher_sg_list {
27         struct list_head list;
28
29         int cur;
30
31         struct scatterlist sg[0];
32 };
33
34 struct skcipher_tfm {
35         struct crypto_skcipher *skcipher;
36         bool has_key;
37 };
38
39 struct skcipher_ctx {
40         struct list_head tsgl;
41         struct af_alg_sgl rsgl;
42
43         void *iv;
44
45         struct af_alg_completion completion;
46
47         atomic_t inflight;
48         unsigned used;
49
50         unsigned int len;
51         bool more;
52         bool merge;
53         bool enc;
54
55         struct skcipher_request req;
56 };
57
58 struct skcipher_async_rsgl {
59         struct af_alg_sgl sgl;
60         struct list_head list;
61 };
62
63 struct skcipher_async_req {
64         struct kiocb *iocb;
65         struct skcipher_async_rsgl first_sgl;
66         struct list_head list;
67         struct scatterlist *tsg;
68         atomic_t *inflight;
69         struct skcipher_request req;
70 };
71
72 #define MAX_SGL_ENTS ((4096 - sizeof(struct skcipher_sg_list)) / \
73                       sizeof(struct scatterlist) - 1)
74
75 static void skcipher_free_async_sgls(struct skcipher_async_req *sreq)
76 {
77         struct skcipher_async_rsgl *rsgl, *tmp;
78         struct scatterlist *sgl;
79         struct scatterlist *sg;
80         int i, n;
81
82         list_for_each_entry_safe(rsgl, tmp, &sreq->list, list) {
83                 af_alg_free_sg(&rsgl->sgl);
84                 if (rsgl != &sreq->first_sgl)
85                         kfree(rsgl);
86         }
87         sgl = sreq->tsg;
88         n = sg_nents(sgl);
89         for_each_sg(sgl, sg, n, i)
90                 put_page(sg_page(sg));
91
92         kfree(sreq->tsg);
93 }
94
95 static void skcipher_async_cb(struct crypto_async_request *req, int err)
96 {
97         struct skcipher_async_req *sreq = req->data;
98         struct kiocb *iocb = sreq->iocb;
99
100         atomic_dec(sreq->inflight);
101         skcipher_free_async_sgls(sreq);
102         kzfree(sreq);
103         iocb->ki_complete(iocb, err, err);
104 }
105
106 static inline int skcipher_sndbuf(struct sock *sk)
107 {
108         struct alg_sock *ask = alg_sk(sk);
109         struct skcipher_ctx *ctx = ask->private;
110
111         return max_t(int, max_t(int, sk->sk_sndbuf & PAGE_MASK, PAGE_SIZE) -
112                           ctx->used, 0);
113 }
114
115 static inline bool skcipher_writable(struct sock *sk)
116 {
117         return PAGE_SIZE <= skcipher_sndbuf(sk);
118 }
119
120 static int skcipher_alloc_sgl(struct sock *sk)
121 {
122         struct alg_sock *ask = alg_sk(sk);
123         struct skcipher_ctx *ctx = ask->private;
124         struct skcipher_sg_list *sgl;
125         struct scatterlist *sg = NULL;
126
127         sgl = list_entry(ctx->tsgl.prev, struct skcipher_sg_list, list);
128         if (!list_empty(&ctx->tsgl))
129                 sg = sgl->sg;
130
131         if (!sg || sgl->cur >= MAX_SGL_ENTS) {
132                 sgl = sock_kmalloc(sk, sizeof(*sgl) +
133                                        sizeof(sgl->sg[0]) * (MAX_SGL_ENTS + 1),
134                                    GFP_KERNEL);
135                 if (!sgl)
136                         return -ENOMEM;
137
138                 sg_init_table(sgl->sg, MAX_SGL_ENTS + 1);
139                 sgl->cur = 0;
140
141                 if (sg)
142                         sg_chain(sg, MAX_SGL_ENTS + 1, sgl->sg);
143
144                 list_add_tail(&sgl->list, &ctx->tsgl);
145         }
146
147         return 0;
148 }
149
150 static void skcipher_pull_sgl(struct sock *sk, int used, int put)
151 {
152         struct alg_sock *ask = alg_sk(sk);
153         struct skcipher_ctx *ctx = ask->private;
154         struct skcipher_sg_list *sgl;
155         struct scatterlist *sg;
156         int i;
157
158         while (!list_empty(&ctx->tsgl)) {
159                 sgl = list_first_entry(&ctx->tsgl, struct skcipher_sg_list,
160                                        list);
161                 sg = sgl->sg;
162
163                 for (i = 0; i < sgl->cur; i++) {
164                         int plen = min_t(int, used, sg[i].length);
165
166                         if (!sg_page(sg + i))
167                                 continue;
168
169                         sg[i].length -= plen;
170                         sg[i].offset += plen;
171
172                         used -= plen;
173                         ctx->used -= plen;
174
175                         if (sg[i].length)
176                                 return;
177                         if (put)
178                                 put_page(sg_page(sg + i));
179                         sg_assign_page(sg + i, NULL);
180                 }
181
182                 list_del(&sgl->list);
183                 sock_kfree_s(sk, sgl,
184                              sizeof(*sgl) + sizeof(sgl->sg[0]) *
185                                             (MAX_SGL_ENTS + 1));
186         }
187
188         if (!ctx->used)
189                 ctx->merge = 0;
190 }
191
192 static void skcipher_free_sgl(struct sock *sk)
193 {
194         struct alg_sock *ask = alg_sk(sk);
195         struct skcipher_ctx *ctx = ask->private;
196
197         skcipher_pull_sgl(sk, ctx->used, 1);
198 }
199
200 static int skcipher_wait_for_wmem(struct sock *sk, unsigned flags)
201 {
202         long timeout;
203         DEFINE_WAIT(wait);
204         int err = -ERESTARTSYS;
205
206         if (flags & MSG_DONTWAIT)
207                 return -EAGAIN;
208
209         sk_set_bit(SOCKWQ_ASYNC_NOSPACE, sk);
210
211         for (;;) {
212                 if (signal_pending(current))
213                         break;
214                 prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
215                 timeout = MAX_SCHEDULE_TIMEOUT;
216                 if (sk_wait_event(sk, &timeout, skcipher_writable(sk))) {
217                         err = 0;
218                         break;
219                 }
220         }
221         finish_wait(sk_sleep(sk), &wait);
222
223         return err;
224 }
225
226 static void skcipher_wmem_wakeup(struct sock *sk)
227 {
228         struct socket_wq *wq;
229
230         if (!skcipher_writable(sk))
231                 return;
232
233         rcu_read_lock();
234         wq = rcu_dereference(sk->sk_wq);
235         if (wq_has_sleeper(wq))
236                 wake_up_interruptible_sync_poll(&wq->wait, POLLIN |
237                                                            POLLRDNORM |
238                                                            POLLRDBAND);
239         sk_wake_async(sk, SOCK_WAKE_WAITD, POLL_IN);
240         rcu_read_unlock();
241 }
242
243 static int skcipher_wait_for_data(struct sock *sk, unsigned flags)
244 {
245         struct alg_sock *ask = alg_sk(sk);
246         struct skcipher_ctx *ctx = ask->private;
247         long timeout;
248         DEFINE_WAIT(wait);
249         int err = -ERESTARTSYS;
250
251         if (flags & MSG_DONTWAIT) {
252                 return -EAGAIN;
253         }
254
255         sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
256
257         for (;;) {
258                 if (signal_pending(current))
259                         break;
260                 prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
261                 timeout = MAX_SCHEDULE_TIMEOUT;
262                 if (sk_wait_event(sk, &timeout, ctx->used)) {
263                         err = 0;
264                         break;
265                 }
266         }
267         finish_wait(sk_sleep(sk), &wait);
268
269         sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
270
271         return err;
272 }
273
274 static void skcipher_data_wakeup(struct sock *sk)
275 {
276         struct alg_sock *ask = alg_sk(sk);
277         struct skcipher_ctx *ctx = ask->private;
278         struct socket_wq *wq;
279
280         if (!ctx->used)
281                 return;
282
283         rcu_read_lock();
284         wq = rcu_dereference(sk->sk_wq);
285         if (wq_has_sleeper(wq))
286                 wake_up_interruptible_sync_poll(&wq->wait, POLLOUT |
287                                                            POLLRDNORM |
288                                                            POLLRDBAND);
289         sk_wake_async(sk, SOCK_WAKE_SPACE, POLL_OUT);
290         rcu_read_unlock();
291 }
292
293 static int skcipher_sendmsg(struct socket *sock, struct msghdr *msg,
294                             size_t size)
295 {
296         struct sock *sk = sock->sk;
297         struct alg_sock *ask = alg_sk(sk);
298         struct skcipher_ctx *ctx = ask->private;
299         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(&ctx->req);
300         unsigned ivsize = crypto_skcipher_ivsize(tfm);
301         struct skcipher_sg_list *sgl;
302         struct af_alg_control con = {};
303         long copied = 0;
304         bool enc = 0;
305         bool init = 0;
306         int err;
307         int i;
308
309         if (msg->msg_controllen) {
310                 err = af_alg_cmsg_send(msg, &con);
311                 if (err)
312                         return err;
313
314                 init = 1;
315                 switch (con.op) {
316                 case ALG_OP_ENCRYPT:
317                         enc = 1;
318                         break;
319                 case ALG_OP_DECRYPT:
320                         enc = 0;
321                         break;
322                 default:
323                         return -EINVAL;
324                 }
325
326                 if (con.iv && con.iv->ivlen != ivsize)
327                         return -EINVAL;
328         }
329
330         err = -EINVAL;
331
332         lock_sock(sk);
333         if (!ctx->more && ctx->used)
334                 goto unlock;
335
336         if (init) {
337                 ctx->enc = enc;
338                 if (con.iv)
339                         memcpy(ctx->iv, con.iv->iv, ivsize);
340         }
341
342         while (size) {
343                 struct scatterlist *sg;
344                 unsigned long len = size;
345                 int plen;
346
347                 if (ctx->merge) {
348                         sgl = list_entry(ctx->tsgl.prev,
349                                          struct skcipher_sg_list, list);
350                         sg = sgl->sg + sgl->cur - 1;
351                         len = min_t(unsigned long, len,
352                                     PAGE_SIZE - sg->offset - sg->length);
353
354                         err = memcpy_from_msg(page_address(sg_page(sg)) +
355                                               sg->offset + sg->length,
356                                               msg, len);
357                         if (err)
358                                 goto unlock;
359
360                         sg->length += len;
361                         ctx->merge = (sg->offset + sg->length) &
362                                      (PAGE_SIZE - 1);
363
364                         ctx->used += len;
365                         copied += len;
366                         size -= len;
367                         continue;
368                 }
369
370                 if (!skcipher_writable(sk)) {
371                         err = skcipher_wait_for_wmem(sk, msg->msg_flags);
372                         if (err)
373                                 goto unlock;
374                 }
375
376                 len = min_t(unsigned long, len, skcipher_sndbuf(sk));
377
378                 err = skcipher_alloc_sgl(sk);
379                 if (err)
380                         goto unlock;
381
382                 sgl = list_entry(ctx->tsgl.prev, struct skcipher_sg_list, list);
383                 sg = sgl->sg;
384                 if (sgl->cur)
385                         sg_unmark_end(sg + sgl->cur - 1);
386                 do {
387                         i = sgl->cur;
388                         plen = min_t(int, len, PAGE_SIZE);
389
390                         sg_assign_page(sg + i, alloc_page(GFP_KERNEL));
391                         err = -ENOMEM;
392                         if (!sg_page(sg + i))
393                                 goto unlock;
394
395                         err = memcpy_from_msg(page_address(sg_page(sg + i)),
396                                               msg, plen);
397                         if (err) {
398                                 __free_page(sg_page(sg + i));
399                                 sg_assign_page(sg + i, NULL);
400                                 goto unlock;
401                         }
402
403                         sg[i].length = plen;
404                         len -= plen;
405                         ctx->used += plen;
406                         copied += plen;
407                         size -= plen;
408                         sgl->cur++;
409                 } while (len && sgl->cur < MAX_SGL_ENTS);
410
411                 if (!size)
412                         sg_mark_end(sg + sgl->cur - 1);
413
414                 ctx->merge = plen & (PAGE_SIZE - 1);
415         }
416
417         err = 0;
418
419         ctx->more = msg->msg_flags & MSG_MORE;
420
421 unlock:
422         skcipher_data_wakeup(sk);
423         release_sock(sk);
424
425         return copied ?: err;
426 }
427
428 static ssize_t skcipher_sendpage(struct socket *sock, struct page *page,
429                                  int offset, size_t size, int flags)
430 {
431         struct sock *sk = sock->sk;
432         struct alg_sock *ask = alg_sk(sk);
433         struct skcipher_ctx *ctx = ask->private;
434         struct skcipher_sg_list *sgl;
435         int err = -EINVAL;
436
437         if (flags & MSG_SENDPAGE_NOTLAST)
438                 flags |= MSG_MORE;
439
440         lock_sock(sk);
441         if (!ctx->more && ctx->used)
442                 goto unlock;
443
444         if (!size)
445                 goto done;
446
447         if (!skcipher_writable(sk)) {
448                 err = skcipher_wait_for_wmem(sk, flags);
449                 if (err)
450                         goto unlock;
451         }
452
453         err = skcipher_alloc_sgl(sk);
454         if (err)
455                 goto unlock;
456
457         ctx->merge = 0;
458         sgl = list_entry(ctx->tsgl.prev, struct skcipher_sg_list, list);
459
460         if (sgl->cur)
461                 sg_unmark_end(sgl->sg + sgl->cur - 1);
462
463         sg_mark_end(sgl->sg + sgl->cur);
464         get_page(page);
465         sg_set_page(sgl->sg + sgl->cur, page, size, offset);
466         sgl->cur++;
467         ctx->used += size;
468
469 done:
470         ctx->more = flags & MSG_MORE;
471
472 unlock:
473         skcipher_data_wakeup(sk);
474         release_sock(sk);
475
476         return err ?: size;
477 }
478
479 static int skcipher_all_sg_nents(struct skcipher_ctx *ctx)
480 {
481         struct skcipher_sg_list *sgl;
482         struct scatterlist *sg;
483         int nents = 0;
484
485         list_for_each_entry(sgl, &ctx->tsgl, list) {
486                 sg = sgl->sg;
487
488                 while (!sg->length)
489                         sg++;
490
491                 nents += sg_nents(sg);
492         }
493         return nents;
494 }
495
496 static int skcipher_recvmsg_async(struct socket *sock, struct msghdr *msg,
497                                   int flags)
498 {
499         struct sock *sk = sock->sk;
500         struct alg_sock *ask = alg_sk(sk);
501         struct sock *psk = ask->parent;
502         struct alg_sock *pask = alg_sk(psk);
503         struct skcipher_ctx *ctx = ask->private;
504         struct skcipher_tfm *skc = pask->private;
505         struct crypto_skcipher *tfm = skc->skcipher;
506         struct skcipher_sg_list *sgl;
507         struct scatterlist *sg;
508         struct skcipher_async_req *sreq;
509         struct skcipher_request *req;
510         struct skcipher_async_rsgl *last_rsgl = NULL;
511         unsigned int txbufs = 0, len = 0, tx_nents = skcipher_all_sg_nents(ctx);
512         unsigned int reqsize = crypto_skcipher_reqsize(tfm);
513         unsigned int ivsize = crypto_skcipher_ivsize(tfm);
514         int err = -ENOMEM;
515         bool mark = false;
516         char *iv;
517
518         sreq = kzalloc(sizeof(*sreq) + reqsize + ivsize, GFP_KERNEL);
519         if (unlikely(!sreq))
520                 goto out;
521
522         req = &sreq->req;
523         iv = (char *)(req + 1) + reqsize;
524         sreq->iocb = msg->msg_iocb;
525         INIT_LIST_HEAD(&sreq->list);
526         sreq->inflight = &ctx->inflight;
527
528         lock_sock(sk);
529         sreq->tsg = kcalloc(tx_nents, sizeof(*sg), GFP_KERNEL);
530         if (unlikely(!sreq->tsg))
531                 goto unlock;
532         sg_init_table(sreq->tsg, tx_nents);
533         memcpy(iv, ctx->iv, ivsize);
534         skcipher_request_set_tfm(req, tfm);
535         skcipher_request_set_callback(req, CRYPTO_TFM_REQ_MAY_BACKLOG,
536                                       skcipher_async_cb, sreq);
537
538         while (iov_iter_count(&msg->msg_iter)) {
539                 struct skcipher_async_rsgl *rsgl;
540                 int used;
541
542                 if (!ctx->used) {
543                         err = skcipher_wait_for_data(sk, flags);
544                         if (err)
545                                 goto free;
546                 }
547                 sgl = list_first_entry(&ctx->tsgl,
548                                        struct skcipher_sg_list, list);
549                 sg = sgl->sg;
550
551                 while (!sg->length)
552                         sg++;
553
554                 used = min_t(unsigned long, ctx->used,
555                              iov_iter_count(&msg->msg_iter));
556                 used = min_t(unsigned long, used, sg->length);
557
558                 if (txbufs == tx_nents) {
559                         struct scatterlist *tmp;
560                         int x;
561                         /* Ran out of tx slots in async request
562                          * need to expand */
563                         tmp = kcalloc(tx_nents * 2, sizeof(*tmp),
564                                       GFP_KERNEL);
565                         if (!tmp)
566                                 goto free;
567
568                         sg_init_table(tmp, tx_nents * 2);
569                         for (x = 0; x < tx_nents; x++)
570                                 sg_set_page(&tmp[x], sg_page(&sreq->tsg[x]),
571                                             sreq->tsg[x].length,
572                                             sreq->tsg[x].offset);
573                         kfree(sreq->tsg);
574                         sreq->tsg = tmp;
575                         tx_nents *= 2;
576                         mark = true;
577                 }
578                 /* Need to take over the tx sgl from ctx
579                  * to the asynch req - these sgls will be freed later */
580                 sg_set_page(sreq->tsg + txbufs++, sg_page(sg), sg->length,
581                             sg->offset);
582
583                 if (list_empty(&sreq->list)) {
584                         rsgl = &sreq->first_sgl;
585                         list_add_tail(&rsgl->list, &sreq->list);
586                 } else {
587                         rsgl = kmalloc(sizeof(*rsgl), GFP_KERNEL);
588                         if (!rsgl) {
589                                 err = -ENOMEM;
590                                 goto free;
591                         }
592                         list_add_tail(&rsgl->list, &sreq->list);
593                 }
594
595                 used = af_alg_make_sg(&rsgl->sgl, &msg->msg_iter, used);
596                 err = used;
597                 if (used < 0)
598                         goto free;
599                 if (last_rsgl)
600                         af_alg_link_sg(&last_rsgl->sgl, &rsgl->sgl);
601
602                 last_rsgl = rsgl;
603                 len += used;
604                 skcipher_pull_sgl(sk, used, 0);
605                 iov_iter_advance(&msg->msg_iter, used);
606         }
607
608         if (mark)
609                 sg_mark_end(sreq->tsg + txbufs - 1);
610
611         skcipher_request_set_crypt(req, sreq->tsg, sreq->first_sgl.sgl.sg,
612                                    len, iv);
613         err = ctx->enc ? crypto_skcipher_encrypt(req) :
614                          crypto_skcipher_decrypt(req);
615         if (err == -EINPROGRESS) {
616                 atomic_inc(&ctx->inflight);
617                 err = -EIOCBQUEUED;
618                 sreq = NULL;
619                 goto unlock;
620         }
621 free:
622         skcipher_free_async_sgls(sreq);
623 unlock:
624         skcipher_wmem_wakeup(sk);
625         release_sock(sk);
626         kzfree(sreq);
627 out:
628         return err;
629 }
630
631 static int skcipher_recvmsg_sync(struct socket *sock, struct msghdr *msg,
632                                  int flags)
633 {
634         struct sock *sk = sock->sk;
635         struct alg_sock *ask = alg_sk(sk);
636         struct skcipher_ctx *ctx = ask->private;
637         unsigned bs = crypto_skcipher_blocksize(crypto_skcipher_reqtfm(
638                 &ctx->req));
639         struct skcipher_sg_list *sgl;
640         struct scatterlist *sg;
641         int err = -EAGAIN;
642         int used;
643         long copied = 0;
644
645         lock_sock(sk);
646         while (msg_data_left(msg)) {
647                 if (!ctx->used) {
648                         err = skcipher_wait_for_data(sk, flags);
649                         if (err)
650                                 goto unlock;
651                 }
652
653                 used = min_t(unsigned long, ctx->used, msg_data_left(msg));
654
655                 used = af_alg_make_sg(&ctx->rsgl, &msg->msg_iter, used);
656                 err = used;
657                 if (err < 0)
658                         goto unlock;
659
660                 if (ctx->more || used < ctx->used)
661                         used -= used % bs;
662
663                 err = -EINVAL;
664                 if (!used)
665                         goto free;
666
667                 sgl = list_first_entry(&ctx->tsgl,
668                                        struct skcipher_sg_list, list);
669                 sg = sgl->sg;
670
671                 while (!sg->length)
672                         sg++;
673
674                 skcipher_request_set_crypt(&ctx->req, sg, ctx->rsgl.sg, used,
675                                            ctx->iv);
676
677                 err = af_alg_wait_for_completion(
678                                 ctx->enc ?
679                                         crypto_skcipher_encrypt(&ctx->req) :
680                                         crypto_skcipher_decrypt(&ctx->req),
681                                 &ctx->completion);
682
683 free:
684                 af_alg_free_sg(&ctx->rsgl);
685
686                 if (err)
687                         goto unlock;
688
689                 copied += used;
690                 skcipher_pull_sgl(sk, used, 1);
691                 iov_iter_advance(&msg->msg_iter, used);
692         }
693
694         err = 0;
695
696 unlock:
697         skcipher_wmem_wakeup(sk);
698         release_sock(sk);
699
700         return copied ?: err;
701 }
702
703 static int skcipher_recvmsg(struct socket *sock, struct msghdr *msg,
704                             size_t ignored, int flags)
705 {
706         return (msg->msg_iocb && !is_sync_kiocb(msg->msg_iocb)) ?
707                 skcipher_recvmsg_async(sock, msg, flags) :
708                 skcipher_recvmsg_sync(sock, msg, flags);
709 }
710
711 static unsigned int skcipher_poll(struct file *file, struct socket *sock,
712                                   poll_table *wait)
713 {
714         struct sock *sk = sock->sk;
715         struct alg_sock *ask = alg_sk(sk);
716         struct skcipher_ctx *ctx = ask->private;
717         unsigned int mask;
718
719         sock_poll_wait(file, sk_sleep(sk), wait);
720         mask = 0;
721
722         if (ctx->used)
723                 mask |= POLLIN | POLLRDNORM;
724
725         if (skcipher_writable(sk))
726                 mask |= POLLOUT | POLLWRNORM | POLLWRBAND;
727
728         return mask;
729 }
730
731 static struct proto_ops algif_skcipher_ops = {
732         .family         =       PF_ALG,
733
734         .connect        =       sock_no_connect,
735         .socketpair     =       sock_no_socketpair,
736         .getname        =       sock_no_getname,
737         .ioctl          =       sock_no_ioctl,
738         .listen         =       sock_no_listen,
739         .shutdown       =       sock_no_shutdown,
740         .getsockopt     =       sock_no_getsockopt,
741         .mmap           =       sock_no_mmap,
742         .bind           =       sock_no_bind,
743         .accept         =       sock_no_accept,
744         .setsockopt     =       sock_no_setsockopt,
745
746         .release        =       af_alg_release,
747         .sendmsg        =       skcipher_sendmsg,
748         .sendpage       =       skcipher_sendpage,
749         .recvmsg        =       skcipher_recvmsg,
750         .poll           =       skcipher_poll,
751 };
752
753 static int skcipher_check_key(struct socket *sock)
754 {
755         int err = 0;
756         struct sock *psk;
757         struct alg_sock *pask;
758         struct skcipher_tfm *tfm;
759         struct sock *sk = sock->sk;
760         struct alg_sock *ask = alg_sk(sk);
761
762         lock_sock(sk);
763         if (ask->refcnt)
764                 goto unlock_child;
765
766         psk = ask->parent;
767         pask = alg_sk(ask->parent);
768         tfm = pask->private;
769
770         err = -ENOKEY;
771         lock_sock_nested(psk, SINGLE_DEPTH_NESTING);
772         if (!tfm->has_key)
773                 goto unlock;
774
775         if (!pask->refcnt++)
776                 sock_hold(psk);
777
778         ask->refcnt = 1;
779         sock_put(psk);
780
781         err = 0;
782
783 unlock:
784         release_sock(psk);
785 unlock_child:
786         release_sock(sk);
787
788         return err;
789 }
790
791 static int skcipher_sendmsg_nokey(struct socket *sock, struct msghdr *msg,
792                                   size_t size)
793 {
794         int err;
795
796         err = skcipher_check_key(sock);
797         if (err)
798                 return err;
799
800         return skcipher_sendmsg(sock, msg, size);
801 }
802
803 static ssize_t skcipher_sendpage_nokey(struct socket *sock, struct page *page,
804                                        int offset, size_t size, int flags)
805 {
806         int err;
807
808         err = skcipher_check_key(sock);
809         if (err)
810                 return err;
811
812         return skcipher_sendpage(sock, page, offset, size, flags);
813 }
814
815 static int skcipher_recvmsg_nokey(struct socket *sock, struct msghdr *msg,
816                                   size_t ignored, int flags)
817 {
818         int err;
819
820         err = skcipher_check_key(sock);
821         if (err)
822                 return err;
823
824         return skcipher_recvmsg(sock, msg, ignored, flags);
825 }
826
827 static struct proto_ops algif_skcipher_ops_nokey = {
828         .family         =       PF_ALG,
829
830         .connect        =       sock_no_connect,
831         .socketpair     =       sock_no_socketpair,
832         .getname        =       sock_no_getname,
833         .ioctl          =       sock_no_ioctl,
834         .listen         =       sock_no_listen,
835         .shutdown       =       sock_no_shutdown,
836         .getsockopt     =       sock_no_getsockopt,
837         .mmap           =       sock_no_mmap,
838         .bind           =       sock_no_bind,
839         .accept         =       sock_no_accept,
840         .setsockopt     =       sock_no_setsockopt,
841
842         .release        =       af_alg_release,
843         .sendmsg        =       skcipher_sendmsg_nokey,
844         .sendpage       =       skcipher_sendpage_nokey,
845         .recvmsg        =       skcipher_recvmsg_nokey,
846         .poll           =       skcipher_poll,
847 };
848
849 static void *skcipher_bind(const char *name, u32 type, u32 mask)
850 {
851         struct skcipher_tfm *tfm;
852         struct crypto_skcipher *skcipher;
853
854         tfm = kzalloc(sizeof(*tfm), GFP_KERNEL);
855         if (!tfm)
856                 return ERR_PTR(-ENOMEM);
857
858         skcipher = crypto_alloc_skcipher(name, type, mask);
859         if (IS_ERR(skcipher)) {
860                 kfree(tfm);
861                 return ERR_CAST(skcipher);
862         }
863
864         tfm->skcipher = skcipher;
865
866         return tfm;
867 }
868
869 static void skcipher_release(void *private)
870 {
871         struct skcipher_tfm *tfm = private;
872
873         crypto_free_skcipher(tfm->skcipher);
874         kfree(tfm);
875 }
876
877 static int skcipher_setkey(void *private, const u8 *key, unsigned int keylen)
878 {
879         struct skcipher_tfm *tfm = private;
880         int err;
881
882         err = crypto_skcipher_setkey(tfm->skcipher, key, keylen);
883         tfm->has_key = !err;
884
885         return err;
886 }
887
888 static void skcipher_wait(struct sock *sk)
889 {
890         struct alg_sock *ask = alg_sk(sk);
891         struct skcipher_ctx *ctx = ask->private;
892         int ctr = 0;
893
894         while (atomic_read(&ctx->inflight) && ctr++ < 100)
895                 msleep(100);
896 }
897
898 static void skcipher_sock_destruct(struct sock *sk)
899 {
900         struct alg_sock *ask = alg_sk(sk);
901         struct skcipher_ctx *ctx = ask->private;
902         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(&ctx->req);
903
904         if (atomic_read(&ctx->inflight))
905                 skcipher_wait(sk);
906
907         skcipher_free_sgl(sk);
908         sock_kzfree_s(sk, ctx->iv, crypto_skcipher_ivsize(tfm));
909         sock_kfree_s(sk, ctx, ctx->len);
910         af_alg_release_parent(sk);
911 }
912
913 static int skcipher_accept_parent_nokey(void *private, struct sock *sk)
914 {
915         struct skcipher_ctx *ctx;
916         struct alg_sock *ask = alg_sk(sk);
917         struct skcipher_tfm *tfm = private;
918         struct crypto_skcipher *skcipher = tfm->skcipher;
919         unsigned int len = sizeof(*ctx) + crypto_skcipher_reqsize(skcipher);
920
921         ctx = sock_kmalloc(sk, len, GFP_KERNEL);
922         if (!ctx)
923                 return -ENOMEM;
924
925         ctx->iv = sock_kmalloc(sk, crypto_skcipher_ivsize(skcipher),
926                                GFP_KERNEL);
927         if (!ctx->iv) {
928                 sock_kfree_s(sk, ctx, len);
929                 return -ENOMEM;
930         }
931
932         memset(ctx->iv, 0, crypto_skcipher_ivsize(skcipher));
933
934         INIT_LIST_HEAD(&ctx->tsgl);
935         ctx->len = len;
936         ctx->used = 0;
937         ctx->more = 0;
938         ctx->merge = 0;
939         ctx->enc = 0;
940         atomic_set(&ctx->inflight, 0);
941         af_alg_init_completion(&ctx->completion);
942
943         ask->private = ctx;
944
945         skcipher_request_set_tfm(&ctx->req, skcipher);
946         skcipher_request_set_callback(&ctx->req, CRYPTO_TFM_REQ_MAY_BACKLOG,
947                                       af_alg_complete, &ctx->completion);
948
949         sk->sk_destruct = skcipher_sock_destruct;
950
951         return 0;
952 }
953
954 static int skcipher_accept_parent(void *private, struct sock *sk)
955 {
956         struct skcipher_tfm *tfm = private;
957
958         if (!tfm->has_key && crypto_skcipher_has_setkey(tfm->skcipher))
959                 return -ENOKEY;
960
961         return skcipher_accept_parent_nokey(private, sk);
962 }
963
964 static const struct af_alg_type algif_type_skcipher = {
965         .bind           =       skcipher_bind,
966         .release        =       skcipher_release,
967         .setkey         =       skcipher_setkey,
968         .accept         =       skcipher_accept_parent,
969         .accept_nokey   =       skcipher_accept_parent_nokey,
970         .ops            =       &algif_skcipher_ops,
971         .ops_nokey      =       &algif_skcipher_ops_nokey,
972         .name           =       "skcipher",
973         .owner          =       THIS_MODULE
974 };
975
976 static int __init algif_skcipher_init(void)
977 {
978         return af_alg_register_type(&algif_type_skcipher);
979 }
980
981 static void __exit algif_skcipher_exit(void)
982 {
983         int err = af_alg_unregister_type(&algif_type_skcipher);
984         BUG_ON(err);
985 }
986
987 module_init(algif_skcipher_init);
988 module_exit(algif_skcipher_exit);
989 MODULE_LICENSE("GPL");