diff options
Diffstat (limited to 'web/server/h2o/libh2o/deps/picotls/deps/micro-ecc/scripts/square_avr.py')
-rwxr-xr-x | web/server/h2o/libh2o/deps/picotls/deps/micro-ecc/scripts/square_avr.py | 327 |
1 files changed, 327 insertions, 0 deletions
diff --git a/web/server/h2o/libh2o/deps/picotls/deps/micro-ecc/scripts/square_avr.py b/web/server/h2o/libh2o/deps/picotls/deps/micro-ecc/scripts/square_avr.py new file mode 100755 index 000000000..6571c3b3c --- /dev/null +++ b/web/server/h2o/libh2o/deps/picotls/deps/micro-ecc/scripts/square_avr.py @@ -0,0 +1,327 @@ +#!/usr/bin/env python + +import sys + +if len(sys.argv) < 2: + print "Provide the integer size in bytes" + sys.exit(1) + +size = int(sys.argv[1]) + +if size > 40: + print "This script doesn't work with integer size %s due to laziness" % (size) + sys.exit(1) + +init_size = size - 20 +if size < 20: + init_size = 0 + +def rg(i): + return i + 2 + +def lo(i): + return i + 2 + +def hi(i): + return i + 12 + +def emit(line, *args): + s = '"' + line + r' \n\t"' + print s % args + +#### set up registers +zero = "r25" +emit("ldi %s, 0", zero) # zero register + +if init_size > 0: + emit("movw r28, r26") # y = x + h = (init_size + 1)//2 + + for i in xrange(h): + emit("ld r%s, x+", lo(i)) + emit("adiw r28, %s", size - init_size) # move y to other end + for i in xrange(h): + emit("ld r%s, y+", hi(i)) + + emit("adiw r30, %s", size - init_size) # move z + + if init_size == 1: + emit("mul %s, %s", lo(0), hi(0)) + emit("st z+, r0") + emit("st z+, r1") + else: + #### first one + print "" + emit("ldi r23, 0") + emit("mul %s, %s", lo(0), hi(0)) + emit("st z+, r0") + emit("mov r22, r1") + print "" + + #### rest of initial block, with moving accumulator registers + acc = [22, 23, 24] + for r in xrange(1, h): + emit("ldi r%s, 0", acc[2]) + for i in xrange(0, (r+2)//2): + emit("mul r%s, r%s", lo(i), hi(r - i)) + emit("add r%s, r0", acc[0]) + emit("adc r%s, r1", acc[1]) + emit("adc r%s, %s", acc[2], zero) + emit("st z+, r%s", acc[0]) + print "" + acc = acc[1:] + acc[:1] + + lo_r = range(2, 2 + h) + hi_r = range(12, 12 + h) + + # now we need to start loading more from the high end + for r in xrange(h, init_size): + hi_r = hi_r[1:] + hi_r[:1] + emit("ld r%s, y+", hi_r[h-1]) + + emit("ldi r%s, 0", acc[2]) + for i in xrange(0, (r+2)//2): + emit("mul r%s, r%s", lo(i), hi_r[h - 1 - i]) + emit("add r%s, r0", acc[0]) + emit("adc r%s, r1", acc[1]) + emit("adc r%s, %s", acc[2], zero) + emit("st z+, r%s", acc[0]) + print "" + acc = acc[1:] + acc[:1] + + # loaded all of the high end bytes; now need to start loading the rest of the low end + for r in xrange(1, init_size-h): + lo_r = lo_r[1:] + lo_r[:1] + emit("ld r%s, x+", lo_r[h-1]) + + emit("ldi r%s, 0", acc[2]) + for i in xrange(0, (init_size+1 - r)//2): + emit("mul r%s, r%s", lo_r[i], hi_r[h - 1 - i]) + emit("add r%s, r0", acc[0]) + emit("adc r%s, r1", acc[1]) + emit("adc r%s, %s", acc[2], zero) + emit("st z+, r%s", acc[0]) + print "" + acc = acc[1:] + acc[:1] + + lo_r = lo_r[1:] + lo_r[:1] + emit("ld r%s, x+", lo_r[h-1]) + + # now we have loaded everything, and we just need to finish the last corner + for r in xrange(init_size-h, init_size-1): + emit("ldi r%s, 0", acc[2]) + for i in xrange(0, (init_size+1 - r)//2): + emit("mul r%s, r%s", lo_r[i], hi_r[h - 1 - i]) + emit("add r%s, r0", acc[0]) + emit("adc r%s, r1", acc[1]) + emit("adc r%s, %s", acc[2], zero) + emit("st z+, r%s", acc[0]) + print "" + acc = acc[1:] + acc[:1] + lo_r = lo_r[1:] + lo_r[:1] # make the indexing easy + + emit("mul r%s, r%s", lo_r[0], hi_r[h - 1]) + emit("add r%s, r0", acc[0]) + emit("adc r%s, r1", acc[1]) + emit("st z+, r%s", acc[0]) + emit("st z+, r%s", acc[1]) + print "" + emit("sbiw r26, %s", init_size) # reset x + emit("sbiw r30, %s", size + init_size) # reset z + +# TODO you could do more rows of size 20 here if your integers are larger than 40 bytes + +s = size - init_size + +for i in xrange(s): + emit("ld r%s, x+", rg(i)) + +#### first few columns +# NOTE: this is only valid if size >= 3 +print "" +emit("ldi r23, 0") +emit("mul r%s, r%s", rg(0), rg(0)) +emit("st z+, r0") +emit("mov r22, r1") +print "" +emit("ldi r24, 0") +emit("mul r%s, r%s", rg(0), rg(1)) +emit("add r22, r0") +emit("adc r23, r1") +emit("adc r24, %s", zero) +emit("add r22, r0") +emit("adc r23, r1") +emit("adc r24, %s", zero) +emit("st z+, r22") +print "" +emit("ldi r22, 0") +emit("mul r%s, r%s", rg(0), rg(2)) +emit("add r23, r0") +emit("adc r24, r1") +emit("adc r22, %s", zero) +emit("add r23, r0") +emit("adc r24, r1") +emit("adc r22, %s", zero) +emit("mul r%s, r%s", rg(1), rg(1)) +emit("add r23, r0") +emit("adc r24, r1") +emit("adc r22, %s", zero) +emit("st z+, r23") +print "" + +acc = [23, 24, 22] +old_acc = [28, 29] +for i in xrange(3, s): + emit("ldi r%s, 0", old_acc[1]) + tmp = [acc[1], acc[2]] + acc = [acc[0], old_acc[0], old_acc[1]] + old_acc = tmp + + # gather non-equal words + emit("mul r%s, r%s", rg(0), rg(i)) + emit("mov r%s, r0", acc[0]) + emit("mov r%s, r1", acc[1]) + for j in xrange(1, (i+1)//2): + emit("mul r%s, r%s", rg(j), rg(i-j)) + emit("add r%s, r0", acc[0]) + emit("adc r%s, r1", acc[1]) + emit("adc r%s, %s", acc[2], zero) + # multiply by 2 + emit("lsl r%s", acc[0]) + emit("rol r%s", acc[1]) + emit("rol r%s", acc[2]) + + # add equal word (if any) + if ((i+1) % 2) != 0: + emit("mul r%s, r%s", rg(i//2), rg(i//2)) + emit("add r%s, r0", acc[0]) + emit("adc r%s, r1", acc[1]) + emit("adc r%s, %s", acc[2], zero) + + # add old accumulator + emit("add r%s, r%s", acc[0], old_acc[0]) + emit("adc r%s, r%s", acc[1], old_acc[1]) + emit("adc r%s, %s", acc[2], zero) + + # store + emit("st z+, r%s", acc[0]) + print "" + +regs = range(2, 22) +for i in xrange(init_size): + regs = regs[1:] + regs[:1] + emit("ld r%s, x+", regs[19]) + + for limit in [18, 19]: + emit("ldi r%s, 0", old_acc[1]) + tmp = [acc[1], acc[2]] + acc = [acc[0], old_acc[0], old_acc[1]] + old_acc = tmp + + # gather non-equal words + emit("mul r%s, r%s", regs[0], regs[limit]) + emit("mov r%s, r0", acc[0]) + emit("mov r%s, r1", acc[1]) + for j in xrange(1, (limit+1)//2): + emit("mul r%s, r%s", regs[j], regs[limit-j]) + emit("add r%s, r0", acc[0]) + emit("adc r%s, r1", acc[1]) + emit("adc r%s, %s", acc[2], zero) + + emit("ld r0, z") # load stored value from initial block, and add to accumulator (note z does not increment) + emit("add r%s, r0", acc[0]) + emit("adc r%s, r25", acc[1]) + emit("adc r%s, r25", acc[2]) + + # multiply by 2 + emit("lsl r%s", acc[0]) + emit("rol r%s", acc[1]) + emit("rol r%s", acc[2]) + + # add equal word + if limit == 18: + emit("mul r%s, r%s", regs[9], regs[9]) + emit("add r%s, r0", acc[0]) + emit("adc r%s, r1", acc[1]) + emit("adc r%s, %s", acc[2], zero) + + # add old accumulator + emit("add r%s, r%s", acc[0], old_acc[0]) + emit("adc r%s, r%s", acc[1], old_acc[1]) + emit("adc r%s, %s", acc[2], zero) + + # store + emit("st z+, r%s", acc[0]) + print "" + +for i in xrange(1, s-3): + emit("ldi r%s, 0", old_acc[1]) + tmp = [acc[1], acc[2]] + acc = [acc[0], old_acc[0], old_acc[1]] + old_acc = tmp + + # gather non-equal words + emit("mul r%s, r%s", regs[i], regs[s - 1]) + emit("mov r%s, r0", acc[0]) + emit("mov r%s, r1", acc[1]) + for j in xrange(1, (s-i)//2): + emit("mul r%s, r%s", regs[i+j], regs[s - 1 - j]) + emit("add r%s, r0", acc[0]) + emit("adc r%s, r1", acc[1]) + emit("adc r%s, %s", acc[2], zero) + # multiply by 2 + emit("lsl r%s", acc[0]) + emit("rol r%s", acc[1]) + emit("rol r%s", acc[2]) + + # add equal word (if any) + if ((s-i) % 2) != 0: + emit("mul r%s, r%s", regs[i + (s-i)//2], regs[i + (s-i)//2]) + emit("add r%s, r0", acc[0]) + emit("adc r%s, r1", acc[1]) + emit("adc r%s, %s", acc[2], zero) + + # add old accumulator + emit("add r%s, r%s", acc[0], old_acc[0]) + emit("adc r%s, r%s", acc[1], old_acc[1]) + emit("adc r%s, %s", acc[2], zero) + + # store + emit("st z+, r%s", acc[0]) + print "" + +acc = acc[1:] + acc[:1] +emit("ldi r%s, 0", acc[2]) +emit("mul r%s, r%s", regs[17], regs[19]) +emit("add r%s, r0", acc[0]) +emit("adc r%s, r1", acc[1]) +emit("adc r%s, %s", acc[2], zero) +emit("add r%s, r0", acc[0]) +emit("adc r%s, r1", acc[1]) +emit("adc r%s, %s", acc[2], zero) +emit("mul r%s, r%s", regs[18], regs[18]) +emit("add r%s, r0", acc[0]) +emit("adc r%s, r1", acc[1]) +emit("adc r%s, %s", acc[2], zero) +emit("st z+, r%s", acc[0]) +print "" + +acc = acc[1:] + acc[:1] +emit("ldi r%s, 0", acc[2]) +emit("mul r%s, r%s", regs[18], regs[19]) +emit("add r%s, r0", acc[0]) +emit("adc r%s, r1", acc[1]) +emit("adc r%s, %s", acc[2], zero) +emit("add r%s, r0", acc[0]) +emit("adc r%s, r1", acc[1]) +emit("adc r%s, %s", acc[2], zero) +emit("st z+, r%s", acc[0]) +print "" + +emit("mul r%s, r%s", regs[19], regs[19]) +emit("add r%s, r0", acc[1]) +emit("adc r%s, r1", acc[2]) +emit("st z+, r%s", acc[1]) + +emit("st z+, r%s", acc[2]) +emit("eor r1, r1") |