1 /* Program for computing integer expressions using the GNU Multiple Precision
2    Arithmetic Library.
3 
4 Copyright 1997, 1999-2002, 2005, 2008, 2012, 2015 Free Software Foundation, Inc.
5 
6 This program is free software; you can redistribute it and/or modify it under
7 the terms of the GNU General Public License as published by the Free Software
8 Foundation; either version 3 of the License, or (at your option) any later
9 version.
10 
11 This program is distributed in the hope that it will be useful, but WITHOUT ANY
12 WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
13 PARTICULAR PURPOSE.  See the GNU General Public License for more details.
14 
15 You should have received a copy of the GNU General Public License along with
16 this program.  If not, see https://www.gnu.org/licenses/.  */
17 
18 
/* This expression evaluator works by building an expression tree (using a
   recursive descent parser) which is then evaluated.  The expression tree is
   useful since we want to optimize certain expressions (like a^b % c).

   Usage: pexpr [options] expr ...
   (Assuming you called the executable `pexpr' of course.)

   Command line options:

   -b        print output in binary
   -o        print output in octal
   -d        print output in decimal (the default)
   -x        print output in hexadecimal
   -X        print output in hexadecimal, with upper-case digits
   -b<NUM>   print output in base NUM (2 to 62)
   -t        print timing information
   -v        print the GMP version pexpr is linked with
   -html     output html
   -wml      output wml
   -split    split long lines after every 80th digit
   -noprint  do not print the result, only its approximate number of digits
*/
38 
39 /* Define LIMIT_RESOURCE_USAGE if you want to make sure the program doesn't
40    use up extensive resources (cpu, memory).  Useful for the GMP demo on the
41    GMP web site, since we cannot load the server too much.  */
42 
43 #include "pexpr-config.h"
44 
45 #include <string.h>
46 #include <stdio.h>
47 #include <stdlib.h>
48 #include <setjmp.h>
49 #include <signal.h>
50 #include <ctype.h>
51 
52 #include <time.h>
53 #include <sys/types.h>
54 #include <sys/time.h>
55 #if HAVE_SYS_RESOURCE_H
56 #include <sys/resource.h>
57 #endif
58 
59 #include "gmp.h"
60 
/* SunOS 4 and HPUX 9 don't define a canonical SIGSTKSZ, use a default. */
#ifndef SIGSTKSZ
#define SIGSTKSZ  4096
#endif


/* Execute FUNC as a statement and store in the int lvalue T the user CPU
   time it consumed, in milliseconds as measured by cputime ().  T is
   evaluated exactly once, after FUNC has run.  The timer local is named
   without a leading "__", since such identifiers are reserved for the C
   implementation.  */
#define TIME(t,func)                                    \
  do {                                                  \
    int time_start_ = cputime ();                       \
    {func;}                                             \
    (t) = cputime () - time_start_;                     \
  } while (0)
74 
/* GMP version 1.x compatibility: supply the GMP 2+ type names used in this
   file and map the newer function names onto the GMP 1 ones.  Note that
   mpz_sgn evaluates Z twice.  */
#if ! (__GNU_MP_VERSION >= 2)
typedef MP_INT __mpz_struct;
typedef __mpz_struct mpz_t[1];
typedef __mpz_struct *mpz_ptr;
#define mpz_fdiv_q  mpz_div
#define mpz_fdiv_r  mpz_mod
#define mpz_tdiv_q_2exp       mpz_div_2exp
#define mpz_sgn(Z) ((Z)->size < 0 ? -1 : (Z)->size > 0)
#endif

/* GMP version 2.0 compatibility: before 2.1 there is no mpz_swap, so swap
   the two structs by value.  A and B are each evaluated twice and are not
   parenthesized, so pass plain variables only.  */
#if ! (__GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1)
#define mpz_swap(a,b) \
  do { __mpz_struct __t; __t = *a; *a = *b; *b = __t;} while (0)
#endif
91 
/* Target for longjmp on parse and evaluation errors.  The parser passes the
   error position as (int) (long) pointer; evaluation errors pass 1.  The
   message is left in the global ERROR.  */
jmp_buf errjmpbuf;

/* Operator of an expression tree node.  LIT marks a literal leaf; NOP
   terminates the fns[] table; the rest are operators and functions.  */
enum op_t {NOP, LIT, NEG, NOT, PLUS, MINUS, MULT, DIV, MOD, REM, INVMOD, POW,
             AND, IOR, XOR, SLL, SRA, POPCNT, HAMDIST, GCD, LCM, SQRT, ROOT, FAC,
             LOG, LOG2, FERMAT, MERSENNE, FIBONACCI, RANDOM, NEXTPRIME, BINOM,
             TIMING};

/* Type for the expression tree.  A LIT node keeps its value in
   operands.val; every other node uses operands.ops, whose RHS is NULL for
   unary operators (NEG, NOT, FAC, one-argument functions).  */
struct expr
{
  enum op_t op;
  union
  {
    struct {struct expr *lhs, *rhs;} ops;
    mpz_t val;
  } operands;
};

typedef struct expr *expr_t;

void cleanup_and_exit (int);

/* Recursive descent parser: each function parses from its char * argument,
   stores the tree through the expr_t * argument, and returns the position
   just after what it consumed.  */
char *skipspace (char *);
void makeexp (expr_t *, enum op_t, expr_t, expr_t);
void free_expr (expr_t);
char *expr (char *, expr_t *);
char *term (char *, expr_t *);
char *power (char *, expr_t *);
char *factor (char *, expr_t *);
int match (char *, char *);
int matchp (char *, char *);
int cputime (void);

void mpz_eval_expr (mpz_ptr, expr_t);
void mpz_eval_mod_expr (mpz_ptr, expr_t, mpz_ptr);

char *error;                    /* message reported after a longjmp to errjmpbuf */
int flag_print = 1;             /* print results; cleared by -noprint */
int print_timing = 0;           /* -t */
int flag_html = 0;              /* -html */
int flag_wml = 0;               /* -wml */
int flag_splitup_output = 0;    /* -split: break output every 80 digits */
char *newline = "";             /* markup printed before each '\n' ("<br>", "<br/>") */
gmp_randstate_t rstate;         /* random state, initialized and seeded in main */
136 
137 
138 
/* cputime() returns user CPU time measured in milliseconds, or 0 when the
   time cannot be determined.  */
#if ! HAVE_CPUTIME
#if HAVE_GETRUSAGE
int
cputime (void)
{
  struct rusage rus;

  /* POSIX names the calling process RUSAGE_SELF; fall back to the literal
     0 where the macro is missing.  On failure RUS would be uninitialized,
     so report no time instead of reading it.  */
#ifdef RUSAGE_SELF
  if (getrusage (RUSAGE_SELF, &rus) != 0)
#else
  if (getrusage (0, &rus) != 0)
#endif
    return 0;
  return rus.ru_utime.tv_sec * 1000 + rus.ru_utime.tv_usec / 1000;
}
#else
#if HAVE_CLOCK
int
cputime (void)
{
  clock_t ticks = clock ();

  /* clock () returns (clock_t) -1 when processor time is unavailable.  */
  if (ticks == (clock_t) -1)
    return 0;
  /* Multiply first for coarse clocks to keep precision; divide first for
     fine-grained ones to limit overflow.  */
  if (CLOCKS_PER_SEC < 100000)
    return ticks * 1000 / CLOCKS_PER_SEC;
  return ticks / (CLOCKS_PER_SEC / 1000);
}
#else
int
cputime (void)
{
  return 0;
}
#endif
#endif
#endif
168 
169 
/* Return nonzero if this function's local lies below XP, which should
   point to a local of the caller; i.e. nonzero if the stack grows towards
   lower addresses.  Addresses are compared as integers because a
   relational comparison of pointers to distinct objects is undefined in C.
   This assumes pointer-to-integer conversion preserves address order, as
   on common flat-memory platforms.  */
int
stack_downwards_helper (char *xp)
{
  char y;
  return (unsigned long) &y < (unsigned long) xp;
}

/* Return nonzero if the stack grows downwards.  The helper is reached
   through a volatile function pointer so the compiler cannot inline it;
   inlining would put both locals in the same frame and make the probe
   meaningless.  */
int
stack_downwards_p (void)
{
  char x;
  int (*volatile probe) (char *) = stack_downwards_helper;
  return probe (&x);
}
182 
183 
184 void
setup_error_handler(void)185 setup_error_handler (void)
186 {
187 #if HAVE_SIGACTION
188   struct sigaction act;
189   act.sa_handler = cleanup_and_exit;
190   sigemptyset (&(act.sa_mask));
191 #define SIGNAL(sig)  sigaction (sig, &act, NULL)
192 #else
193   struct { int sa_flags; } act;
194 #define SIGNAL(sig)  signal (sig, cleanup_and_exit)
195 #endif
196   act.sa_flags = 0;
197 
198   /* Set up a stack for signal handling.  A typical cause of error is stack
199      overflow, and in such situation a signal can not be delivered on the
200      overflown stack.  */
201 #if HAVE_SIGALTSTACK
202   {
203     /* AIX uses stack_t, MacOS uses struct sigaltstack, various other
204        systems have both. */
205 #if HAVE_STACK_T
206     stack_t s;
207 #else
208     struct sigaltstack s;
209 #endif
210     s.ss_sp = malloc (SIGSTKSZ);
211     s.ss_size = SIGSTKSZ;
212     s.ss_flags = 0;
213     if (sigaltstack (&s, NULL) != 0)
214       perror("sigaltstack");
215     act.sa_flags = SA_ONSTACK;
216   }
217 #else
218 #if HAVE_SIGSTACK
219   {
220     struct sigstack s;
221     s.ss_sp = malloc (SIGSTKSZ);
222     if (stack_downwards_p ())
223       s.ss_sp += SIGSTKSZ;
224     s.ss_onstack = 0;
225     if (sigstack (&s, NULL) != 0)
226       perror("sigstack");
227     act.sa_flags = SA_ONSTACK;
228   }
229 #else
230 #endif
231 #endif
232 
233 #ifdef LIMIT_RESOURCE_USAGE
234   {
235     struct rlimit limit;
236 
237     limit.rlim_cur = limit.rlim_max = 0;
238     setrlimit (RLIMIT_CORE, &limit);
239 
240     limit.rlim_cur = 3;
241     limit.rlim_max = 4;
242     setrlimit (RLIMIT_CPU, &limit);
243 
244     limit.rlim_cur = limit.rlim_max = 16 * 1024 * 1024;
245     setrlimit (RLIMIT_DATA, &limit);
246 
247     getrlimit (RLIMIT_STACK, &limit);
248     limit.rlim_cur = 4 * 1024 * 1024;
249     setrlimit (RLIMIT_STACK, &limit);
250 
251     SIGNAL (SIGXCPU);
252   }
253 #endif /* LIMIT_RESOURCE_USAGE */
254 
255   SIGNAL (SIGILL);
256   SIGNAL (SIGSEGV);
257 #ifdef SIGBUS /* not in mingw */
258   SIGNAL (SIGBUS);
259 #endif
260   SIGNAL (SIGFPE);
261   SIGNAL (SIGABRT);
262 }
263 
264 int
main(int argc,char ** argv)265 main (int argc, char **argv)
266 {
267   struct expr *e;
268   int i;
269   mpz_t r;
270   int errcode = 0;
271   char *str;
272   int base = 10;
273 
274   setup_error_handler ();
275 
276   gmp_randinit (rstate, GMP_RAND_ALG_LC, 128);
277 
278   {
279 #if HAVE_GETTIMEOFDAY
280     struct timeval tv;
281     gettimeofday (&tv, NULL);
282     gmp_randseed_ui (rstate, tv.tv_sec + tv.tv_usec);
283 #else
284     time_t t;
285     time (&t);
286     gmp_randseed_ui (rstate, t);
287 #endif
288   }
289 
290   mpz_init (r);
291 
292   while (argc > 1 && argv[1][0] == '-')
293     {
294       char *arg = argv[1];
295 
296       if (arg[1] >= '0' && arg[1] <= '9')
297           break;
298 
299       if (arg[1] == 't')
300           print_timing = 1;
301       else if (arg[1] == 'b' && arg[2] >= '0' && arg[2] <= '9')
302           {
303             base = atoi (arg + 2);
304             if (base < 2 || base > 62)
305               {
306                 fprintf (stderr, "error: invalid output base\n");
307                 exit (-1);
308               }
309           }
310       else if (arg[1] == 'b' && arg[2] == 0)
311           base = 2;
312       else if (arg[1] == 'x' && arg[2] == 0)
313           base = 16;
314       else if (arg[1] == 'X' && arg[2] == 0)
315           base = -16;
316       else if (arg[1] == 'o' && arg[2] == 0)
317           base = 8;
318       else if (arg[1] == 'd' && arg[2] == 0)
319           base = 10;
320       else if (arg[1] == 'v' && arg[2] == 0)
321           {
322             printf ("pexpr linked to gmp %s\n", __gmp_version);
323           }
324       else if (strcmp (arg, "-html") == 0)
325           {
326             flag_html = 1;
327             newline = "<br>";
328           }
329       else if (strcmp (arg, "-wml") == 0)
330           {
331             flag_wml = 1;
332             newline = "<br/>";
333           }
334       else if (strcmp (arg, "-split") == 0)
335           {
336             flag_splitup_output = 1;
337           }
338       else if (strcmp (arg, "-noprint") == 0)
339           {
340             flag_print = 0;
341           }
342       else
343           {
344             fprintf (stderr, "error: unknown option `%s'\n", arg);
345             exit (-1);
346           }
347       argv++;
348       argc--;
349     }
350 
351   for (i = 1; i < argc; i++)
352     {
353       int s;
354       int jmpval;
355 
356       /* Set up error handler for parsing expression.  */
357       jmpval = setjmp (errjmpbuf);
358       if (jmpval != 0)
359           {
360             fprintf (stderr, "error: %s%s\n", error, newline);
361             fprintf (stderr, "       %s%s\n", argv[i], newline);
362             if (! flag_html)
363               {
364                 /* ??? Dunno how to align expression position with arrow in
365                      HTML ??? */
366                 fprintf (stderr, "       ");
367                 for (s = jmpval - (long) argv[i]; --s >= 0; )
368                     putc (' ', stderr);
369                 fprintf (stderr, "^\n");
370               }
371 
372             errcode |= 1;
373             continue;
374           }
375 
376       str = expr (argv[i], &e);
377 
378       if (str[0] != 0)
379           {
380             fprintf (stderr,
381                        "error: garbage where end of expression expected%s\n",
382                        newline);
383             fprintf (stderr, "       %s%s\n", argv[i], newline);
384             if (! flag_html)
385               {
386                 /* ??? Dunno how to align expression position with arrow in
387                      HTML ??? */
388                 fprintf (stderr, "        ");
389                 for (s = str - argv[i]; --s; )
390                     putc (' ', stderr);
391                 fprintf (stderr, "^\n");
392               }
393 
394             errcode |= 1;
395             free_expr (e);
396             continue;
397           }
398 
399       /* Set up error handler for evaluating expression.  */
400       if (setjmp (errjmpbuf))
401           {
402             fprintf (stderr, "error: %s%s\n", error, newline);
403             fprintf (stderr, "       %s%s\n", argv[i], newline);
404             if (! flag_html)
405               {
406                 /* ??? Dunno how to align expression position with arrow in
407                      HTML ??? */
408                 fprintf (stderr, "       ");
409                 for (s = str - argv[i]; --s >= 0; )
410                     putc (' ', stderr);
411                 fprintf (stderr, "^\n");
412               }
413 
414             errcode |= 2;
415             continue;
416           }
417 
418       if (print_timing)
419           {
420             int t;
421             TIME (t, mpz_eval_expr (r, e));
422             printf ("computation took %d ms%s\n", t, newline);
423           }
424       else
425           mpz_eval_expr (r, e);
426 
427       if (flag_print)
428           {
429             size_t out_len;
430             char *tmp, *s;
431 
432             out_len = mpz_sizeinbase (r, base >= 0 ? base : -base) + 2;
433 #ifdef LIMIT_RESOURCE_USAGE
434             if (out_len > 100000)
435               {
436                 printf ("result is about %ld digits, not printing it%s\n",
437                           (long) out_len - 3, newline);
438                 exit (-2);
439               }
440 #endif
441             tmp = malloc (out_len);
442 
443             if (print_timing)
444               {
445                 int t;
446                 printf ("output conversion ");
447                 TIME (t, mpz_get_str (tmp, base, r));
448                 printf ("took %d ms%s\n", t, newline);
449               }
450             else
451               mpz_get_str (tmp, base, r);
452 
453             out_len = strlen (tmp);
454             if (flag_splitup_output)
455               {
456                 for (s = tmp; out_len > 80; s += 80)
457                     {
458                       fwrite (s, 1, 80, stdout);
459                       printf ("%s\n", newline);
460                       out_len -= 80;
461                     }
462 
463                 fwrite (s, 1, out_len, stdout);
464               }
465             else
466               {
467                 fwrite (tmp, 1, out_len, stdout);
468               }
469 
470             free (tmp);
471             printf ("%s\n", newline);
472           }
473       else
474           {
475             printf ("result is approximately %ld digits%s\n",
476                       (long) mpz_sizeinbase (r, base >= 0 ? base : -base),
477                       newline);
478           }
479 
480       free_expr (e);
481     }
482 
483   mpz_clear (r);
484 
485   exit (errcode);
486 }
487 
488 char *
expr(char * str,expr_t * e)489 expr (char *str, expr_t *e)
490 {
491   expr_t e2;
492 
493   str = skipspace (str);
494   if (str[0] == '+')
495     {
496       str = term (str + 1, e);
497     }
498   else if (str[0] == '-')
499     {
500       str = term (str + 1, e);
501       makeexp (e, NEG, *e, NULL);
502     }
503   else if (str[0] == '~')
504     {
505       str = term (str + 1, e);
506       makeexp (e, NOT, *e, NULL);
507     }
508   else
509     {
510       str = term (str, e);
511     }
512 
513   for (;;)
514     {
515       str = skipspace (str);
516       switch (str[0])
517           {
518           case 'p':
519             if (match ("plus", str))
520               {
521                 str = term (str + 4, &e2);
522                 makeexp (e, PLUS, *e, e2);
523               }
524             else
525               return str;
526             break;
527           case 'm':
528             if (match ("minus", str))
529               {
530                 str = term (str + 5, &e2);
531                 makeexp (e, MINUS, *e, e2);
532               }
533             else
534               return str;
535             break;
536           case '+':
537             str = term (str + 1, &e2);
538             makeexp (e, PLUS, *e, e2);
539             break;
540           case '-':
541             str = term (str + 1, &e2);
542             makeexp (e, MINUS, *e, e2);
543             break;
544           default:
545             return str;
546           }
547     }
548 }
549 
550 char *
term(char * str,expr_t * e)551 term (char *str, expr_t *e)
552 {
553   expr_t e2;
554 
555   str = power (str, e);
556   for (;;)
557     {
558       str = skipspace (str);
559       switch (str[0])
560           {
561           case 'm':
562             if (match ("mul", str))
563               {
564                 str = power (str + 3, &e2);
565                 makeexp (e, MULT, *e, e2);
566                 break;
567               }
568             if (match ("mod", str))
569               {
570                 str = power (str + 3, &e2);
571                 makeexp (e, MOD, *e, e2);
572                 break;
573               }
574             return str;
575           case 'd':
576             if (match ("div", str))
577               {
578                 str = power (str + 3, &e2);
579                 makeexp (e, DIV, *e, e2);
580                 break;
581               }
582             return str;
583           case 'r':
584             if (match ("rem", str))
585               {
586                 str = power (str + 3, &e2);
587                 makeexp (e, REM, *e, e2);
588                 break;
589               }
590             return str;
591           case 'i':
592             if (match ("invmod", str))
593               {
594                 str = power (str + 6, &e2);
595                 makeexp (e, REM, *e, e2);
596                 break;
597               }
598             return str;
599           case 't':
600             if (match ("times", str))
601               {
602                 str = power (str + 5, &e2);
603                 makeexp (e, MULT, *e, e2);
604                 break;
605               }
606             if (match ("thru", str))
607               {
608                 str = power (str + 4, &e2);
609                 makeexp (e, DIV, *e, e2);
610                 break;
611               }
612             if (match ("through", str))
613               {
614                 str = power (str + 7, &e2);
615                 makeexp (e, DIV, *e, e2);
616                 break;
617               }
618             return str;
619           case '*':
620             str = power (str + 1, &e2);
621             makeexp (e, MULT, *e, e2);
622             break;
623           case '/':
624             str = power (str + 1, &e2);
625             makeexp (e, DIV, *e, e2);
626             break;
627           case '%':
628             str = power (str + 1, &e2);
629             makeexp (e, MOD, *e, e2);
630             break;
631           default:
632             return str;
633           }
634     }
635 }
636 
637 char *
power(char * str,expr_t * e)638 power (char *str, expr_t *e)
639 {
640   expr_t e2;
641 
642   str = factor (str, e);
643   while (str[0] == '!')
644     {
645       str++;
646       makeexp (e, FAC, *e, NULL);
647     }
648   str = skipspace (str);
649   if (str[0] == '^')
650     {
651       str = power (str + 1, &e2);
652       makeexp (e, POW, *e, e2);
653     }
654   return str;
655 }
656 
657 int
match(char * s,char * str)658 match (char *s, char *str)
659 {
660   char *ostr = str;
661   int i;
662 
663   for (i = 0; s[i] != 0; i++)
664     {
665       if (str[i] != s[i])
666           return 0;
667     }
668   str = skipspace (str + i);
669   return str - ostr;
670 }
671 
/* Recognize the call syntax "S (" at the start of STR.  If STR starts with
   the word S, optionally followed by spaces and then '(', return the number
   of characters up to and including the '('; otherwise return 0.  */
int
matchp (char *s, char *str)
{
  char *p = str;

  while (*s != 0)
    if (*p++ != *s++)
      return 0;

  p = skipspace (p);
  if (*p != '(')
    return 0;
  return p - str + 1;
}
688 
/* A named function usable in call syntax NAME(args) within an expression.  */
struct functions
{
  char *spelling;
  enum op_t op;
  int arity; /* 1 or 2 means real arity; 0 means two or more operands,
                combined left to right.  */
};

/* Functions recognized by factor ().  Several spellings may map to the same
   op.  The table is terminated by the entry whose op is NOP.  */
struct functions fns[] =
{
  {"sqrt", SQRT, 1},
#if __GNU_MP_VERSION >= 2
  {"root", ROOT, 2},
  {"popc", POPCNT, 1},
  {"hamdist", HAMDIST, 2},
#endif
  {"gcd", GCD, 0},
#if __GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1
  {"lcm", LCM, 0},
#endif
  {"and", AND, 0},
  {"ior", IOR, 0},
#if __GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1
  {"xor", XOR, 0},
#endif
  {"plus", PLUS, 0},
  {"pow", POW, 2},
  {"minus", MINUS, 2},
  {"mul", MULT, 0},
  {"div", DIV, 2},
  {"mod", MOD, 2},
  {"rem", REM, 2},
#if __GNU_MP_VERSION >= 2
  {"invmod", INVMOD, 2},
#endif
  {"log", LOG, 2},
  {"log2", LOG2, 1},
  {"F", FERMAT, 1},
  {"M", MERSENNE, 1},
  {"fib", FIBONACCI, 1},
  {"Fib", FIBONACCI, 1},
  {"random", RANDOM, 1},
  {"nextprime", NEXTPRIME, 1},
  {"binom", BINOM, 2},
  {"binomial", BINOM, 2},
  {"fac", FAC, 1},
  {"fact", FAC, 1},
  {"factorial", FAC, 1},
  {"time", TIMING, 1},
  {"", NOP, 0}
};
739 
740 char *
factor(char * str,expr_t * e)741 factor (char *str, expr_t *e)
742 {
743   expr_t e1, e2;
744 
745   str = skipspace (str);
746 
747   if (isalpha (str[0]))
748     {
749       int i;
750       int cnt;
751 
752       for (i = 0; fns[i].op != NOP; i++)
753           {
754             if (fns[i].arity == 1)
755               {
756                 cnt = matchp (fns[i].spelling, str);
757                 if (cnt != 0)
758                     {
759                       str = expr (str + cnt, &e1);
760                       str = skipspace (str);
761                       if (str[0] != ')')
762                         {
763                           error = "expected `)'";
764                           longjmp (errjmpbuf, (int) (long) str);
765                         }
766                       makeexp (e, fns[i].op, e1, NULL);
767                       return str + 1;
768                     }
769               }
770           }
771 
772       for (i = 0; fns[i].op != NOP; i++)
773           {
774             if (fns[i].arity != 1)
775               {
776                 cnt = matchp (fns[i].spelling, str);
777                 if (cnt != 0)
778                     {
779                       str = expr (str + cnt, &e1);
780                       str = skipspace (str);
781 
782                       if (str[0] != ',')
783                         {
784                           error = "expected `,' and another operand";
785                           longjmp (errjmpbuf, (int) (long) str);
786                         }
787 
788                       str = skipspace (str + 1);
789                       str = expr (str, &e2);
790                       str = skipspace (str);
791 
792                       if (fns[i].arity == 0)
793                         {
794                           while (str[0] == ',')
795                               {
796                                 makeexp (&e1, fns[i].op, e1, e2);
797                                 str = skipspace (str + 1);
798                                 str = expr (str, &e2);
799                                 str = skipspace (str);
800                               }
801                         }
802 
803                       if (str[0] != ')')
804                         {
805                           error = "expected `)'";
806                           longjmp (errjmpbuf, (int) (long) str);
807                         }
808 
809                       makeexp (e, fns[i].op, e1, e2);
810                       return str + 1;
811                     }
812               }
813           }
814     }
815 
816   if (str[0] == '(')
817     {
818       str = expr (str + 1, e);
819       str = skipspace (str);
820       if (str[0] != ')')
821           {
822             error = "expected `)'";
823             longjmp (errjmpbuf, (int) (long) str);
824           }
825       str++;
826     }
827   else if (str[0] >= '0' && str[0] <= '9')
828     {
829       expr_t res;
830       char *s, *sc;
831 
832       res = malloc (sizeof (struct expr));
833       res -> op = LIT;
834       mpz_init (res->operands.val);
835 
836       s = str;
837       while (isalnum (str[0]))
838           str++;
839       sc = malloc (str - s + 1);
840       memcpy (sc, s, str - s);
841       sc[str - s] = 0;
842 
843       mpz_set_str (res->operands.val, sc, 0);
844       *e = res;
845       free (sc);
846     }
847   else
848     {
849       error = "operand expected";
850       longjmp (errjmpbuf, (int) (long) str);
851     }
852   return str;
853 }
854 
/* Return STR advanced past any leading white space.  Any character
   isspace () accepts counts, so tabs and newlines inside a quoted argument
   are skipped like blanks.  */
char *
skipspace (char *str)
{
  while (isspace ((unsigned char) str[0]))
    str++;
  return str;
}
862 
863 /* Make a new expression with operation OP and right hand side
864    RHS and left hand side lhs.  Put the result in R.  */
865 void
makeexp(expr_t * r,enum op_t op,expr_t lhs,expr_t rhs)866 makeexp (expr_t *r, enum op_t op, expr_t lhs, expr_t rhs)
867 {
868   expr_t res;
869   res = malloc (sizeof (struct expr));
870   res -> op = op;
871   res -> operands.ops.lhs = lhs;
872   res -> operands.ops.rhs = rhs;
873   *r = res;
874   return;
875 }
876 
877 /* Free the memory used by expression E.  */
878 void
free_expr(expr_t e)879 free_expr (expr_t e)
880 {
881   if (e->op != LIT)
882     {
883       free_expr (e->operands.ops.lhs);
884       if (e->operands.ops.rhs != NULL)
885           free_expr (e->operands.ops.rhs);
886     }
887   else
888     {
889       mpz_clear (e->operands.val);
890     }
891 }
892 
893 /* Evaluate the expression E and put the result in R.  */
894 void
mpz_eval_expr(mpz_ptr r,expr_t e)895 mpz_eval_expr (mpz_ptr r, expr_t e)
896 {
897   mpz_t lhs, rhs;
898 
899   switch (e->op)
900     {
901     case LIT:
902       mpz_set (r, e->operands.val);
903       return;
904     case PLUS:
905       mpz_init (lhs); mpz_init (rhs);
906       mpz_eval_expr (lhs, e->operands.ops.lhs);
907       mpz_eval_expr (rhs, e->operands.ops.rhs);
908       mpz_add (r, lhs, rhs);
909       mpz_clear (lhs); mpz_clear (rhs);
910       return;
911     case MINUS:
912       mpz_init (lhs); mpz_init (rhs);
913       mpz_eval_expr (lhs, e->operands.ops.lhs);
914       mpz_eval_expr (rhs, e->operands.ops.rhs);
915       mpz_sub (r, lhs, rhs);
916       mpz_clear (lhs); mpz_clear (rhs);
917       return;
918     case MULT:
919       mpz_init (lhs); mpz_init (rhs);
920       mpz_eval_expr (lhs, e->operands.ops.lhs);
921       mpz_eval_expr (rhs, e->operands.ops.rhs);
922       mpz_mul (r, lhs, rhs);
923       mpz_clear (lhs); mpz_clear (rhs);
924       return;
925     case DIV:
926       mpz_init (lhs); mpz_init (rhs);
927       mpz_eval_expr (lhs, e->operands.ops.lhs);
928       mpz_eval_expr (rhs, e->operands.ops.rhs);
929       mpz_fdiv_q (r, lhs, rhs);
930       mpz_clear (lhs); mpz_clear (rhs);
931       return;
932     case MOD:
933       mpz_init (rhs);
934       mpz_eval_expr (rhs, e->operands.ops.rhs);
935       mpz_abs (rhs, rhs);
936       mpz_eval_mod_expr (r, e->operands.ops.lhs, rhs);
937       mpz_clear (rhs);
938       return;
939     case REM:
940       /* Check if lhs operand is POW expression and optimize for that case.  */
941       if (e->operands.ops.lhs->op == POW)
942           {
943             mpz_t powlhs, powrhs;
944             mpz_init (powlhs);
945             mpz_init (powrhs);
946             mpz_init (rhs);
947             mpz_eval_expr (powlhs, e->operands.ops.lhs->operands.ops.lhs);
948             mpz_eval_expr (powrhs, e->operands.ops.lhs->operands.ops.rhs);
949             mpz_eval_expr (rhs, e->operands.ops.rhs);
950             mpz_powm (r, powlhs, powrhs, rhs);
951             if (mpz_cmp_si (rhs, 0L) < 0)
952               mpz_neg (r, r);
953             mpz_clear (powlhs);
954             mpz_clear (powrhs);
955             mpz_clear (rhs);
956             return;
957           }
958 
959       mpz_init (lhs); mpz_init (rhs);
960       mpz_eval_expr (lhs, e->operands.ops.lhs);
961       mpz_eval_expr (rhs, e->operands.ops.rhs);
962       mpz_fdiv_r (r, lhs, rhs);
963       mpz_clear (lhs); mpz_clear (rhs);
964       return;
965 #if __GNU_MP_VERSION >= 2
966     case INVMOD:
967       mpz_init (lhs); mpz_init (rhs);
968       mpz_eval_expr (lhs, e->operands.ops.lhs);
969       mpz_eval_expr (rhs, e->operands.ops.rhs);
970       mpz_invert (r, lhs, rhs);
971       mpz_clear (lhs); mpz_clear (rhs);
972       return;
973 #endif
974     case POW:
975       mpz_init (lhs); mpz_init (rhs);
976       mpz_eval_expr (lhs, e->operands.ops.lhs);
977       if (mpz_cmpabs_ui (lhs, 1) <= 0)
978           {
979             /* For 0^rhs and 1^rhs, we just need to verify that
980                rhs is well-defined.  For (-1)^rhs we need to
981                determine (rhs mod 2).  For simplicity, compute
982                (rhs mod 2) for all three cases.  */
983             expr_t two, et;
984             two = malloc (sizeof (struct expr));
985             two -> op = LIT;
986             mpz_init_set_ui (two->operands.val, 2L);
987             makeexp (&et, MOD, e->operands.ops.rhs, two);
988             e->operands.ops.rhs = et;
989           }
990 
991       mpz_eval_expr (rhs, e->operands.ops.rhs);
992       if (mpz_cmp_si (rhs, 0L) == 0)
993           /* x^0 is 1 */
994           mpz_set_ui (r, 1L);
995       else if (mpz_cmp_si (lhs, 0L) == 0)
996           /* 0^y (where y != 0) is 0 */
997           mpz_set_ui (r, 0L);
998       else if (mpz_cmp_ui (lhs, 1L) == 0)
999           /* 1^y is 1 */
1000           mpz_set_ui (r, 1L);
1001       else if (mpz_cmp_si (lhs, -1L) == 0)
1002           /* (-1)^y just depends on whether y is even or odd */
1003           mpz_set_si (r, (mpz_get_ui (rhs) & 1) ? -1L : 1L);
1004       else if (mpz_cmp_si (rhs, 0L) < 0)
1005           /* x^(-n) is 0 */
1006           mpz_set_ui (r, 0L);
1007       else
1008           {
1009             unsigned long int cnt;
1010             unsigned long int y;
1011             /* error if exponent does not fit into an unsigned long int.  */
1012             if (mpz_cmp_ui (rhs, ~(unsigned long int) 0) > 0)
1013               goto pow_err;
1014 
1015             y = mpz_get_ui (rhs);
1016             /* x^y == (x/(2^c))^y * 2^(c*y) */
1017 #if __GNU_MP_VERSION >= 2
1018             cnt = mpz_scan1 (lhs, 0);
1019 #else
1020             cnt = 0;
1021 #endif
1022             if (cnt != 0)
1023               {
1024                 if (y * cnt / cnt != y)
1025                     goto pow_err;
1026                 mpz_tdiv_q_2exp (lhs, lhs, cnt);
1027                 mpz_pow_ui (r, lhs, y);
1028                 mpz_mul_2exp (r, r, y * cnt);
1029               }
1030             else
1031               mpz_pow_ui (r, lhs, y);
1032           }
1033       mpz_clear (lhs); mpz_clear (rhs);
1034       return;
1035     pow_err:
1036       error = "result of `pow' operator too large";
1037       mpz_clear (lhs); mpz_clear (rhs);
1038       longjmp (errjmpbuf, 1);
1039     case GCD:
1040       mpz_init (lhs); mpz_init (rhs);
1041       mpz_eval_expr (lhs, e->operands.ops.lhs);
1042       mpz_eval_expr (rhs, e->operands.ops.rhs);
1043       mpz_gcd (r, lhs, rhs);
1044       mpz_clear (lhs); mpz_clear (rhs);
1045       return;
1046 #if __GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1
1047     case LCM:
1048       mpz_init (lhs); mpz_init (rhs);
1049       mpz_eval_expr (lhs, e->operands.ops.lhs);
1050       mpz_eval_expr (rhs, e->operands.ops.rhs);
1051       mpz_lcm (r, lhs, rhs);
1052       mpz_clear (lhs); mpz_clear (rhs);
1053       return;
1054 #endif
1055     case AND:
1056       mpz_init (lhs); mpz_init (rhs);
1057       mpz_eval_expr (lhs, e->operands.ops.lhs);
1058       mpz_eval_expr (rhs, e->operands.ops.rhs);
1059       mpz_and (r, lhs, rhs);
1060       mpz_clear (lhs); mpz_clear (rhs);
1061       return;
1062     case IOR:
1063       mpz_init (lhs); mpz_init (rhs);
1064       mpz_eval_expr (lhs, e->operands.ops.lhs);
1065       mpz_eval_expr (rhs, e->operands.ops.rhs);
1066       mpz_ior (r, lhs, rhs);
1067       mpz_clear (lhs); mpz_clear (rhs);
1068       return;
1069 #if __GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1
1070     case XOR:
1071       mpz_init (lhs); mpz_init (rhs);
1072       mpz_eval_expr (lhs, e->operands.ops.lhs);
1073       mpz_eval_expr (rhs, e->operands.ops.rhs);
1074       mpz_xor (r, lhs, rhs);
1075       mpz_clear (lhs); mpz_clear (rhs);
1076       return;
1077 #endif
1078     case NEG:
1079       mpz_eval_expr (r, e->operands.ops.lhs);
1080       mpz_neg (r, r);
1081       return;
1082     case NOT:
1083       mpz_eval_expr (r, e->operands.ops.lhs);
1084       mpz_com (r, r);
1085       return;
1086     case SQRT:
1087       mpz_init (lhs);
1088       mpz_eval_expr (lhs, e->operands.ops.lhs);
1089       if (mpz_sgn (lhs) < 0)
1090           {
1091             error = "cannot take square root of negative numbers";
1092             mpz_clear (lhs);
1093             longjmp (errjmpbuf, 1);
1094           }
1095       mpz_sqrt (r, lhs);
1096       return;
1097 #if __GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1
1098     case ROOT:
1099       mpz_init (lhs); mpz_init (rhs);
1100       mpz_eval_expr (lhs, e->operands.ops.lhs);
1101       mpz_eval_expr (rhs, e->operands.ops.rhs);
1102       if (mpz_sgn (rhs) <= 0)
1103           {
1104             error = "cannot take non-positive root orders";
1105             mpz_clear (lhs); mpz_clear (rhs);
1106             longjmp (errjmpbuf, 1);
1107           }
1108       if (mpz_sgn (lhs) < 0 && (mpz_get_ui (rhs) & 1) == 0)
1109           {
1110             error = "cannot take even root orders of negative numbers";
1111             mpz_clear (lhs); mpz_clear (rhs);
1112             longjmp (errjmpbuf, 1);
1113           }
1114 
1115       {
1116           unsigned long int nth = mpz_get_ui (rhs);
1117           if (mpz_cmp_ui (rhs, ~(unsigned long int) 0) > 0)
1118             {
1119               /* If we are asked to take an awfully large root order, cheat and
1120                  ask for the largest order we can pass to mpz_root.  This saves
1121                  some error prone special cases.  */
1122               nth = ~(unsigned long int) 0;
1123             }
1124           mpz_root (r, lhs, nth);
1125       }
1126       mpz_clear (lhs); mpz_clear (rhs);
1127       return;
1128 #endif
1129     case FAC:
1130       mpz_eval_expr (r, e->operands.ops.lhs);
1131       if (mpz_size (r) > 1)
1132           {
1133             error = "result of `!' operator too large";
1134             longjmp (errjmpbuf, 1);
1135           }
1136       mpz_fac_ui (r, mpz_get_ui (r));
1137       return;
1138 #if __GNU_MP_VERSION >= 2
1139     case POPCNT:
1140       mpz_eval_expr (r, e->operands.ops.lhs);
1141       { long int cnt;
1142           cnt = mpz_popcount (r);
1143           mpz_set_si (r, cnt);
1144       }
1145       return;
1146     case HAMDIST:
1147       { long int cnt;
1148           mpz_init (lhs); mpz_init (rhs);
1149           mpz_eval_expr (lhs, e->operands.ops.lhs);
1150           mpz_eval_expr (rhs, e->operands.ops.rhs);
1151           cnt = mpz_hamdist (lhs, rhs);
1152           mpz_clear (lhs); mpz_clear (rhs);
1153           mpz_set_si (r, cnt);
1154       }
1155       return;
1156 #endif
1157     case LOG2:
1158       mpz_eval_expr (r, e->operands.ops.lhs);
1159       { unsigned long int cnt;
1160           if (mpz_sgn (r) <= 0)
1161             {
1162               error = "logarithm of non-positive number";
1163               longjmp (errjmpbuf, 1);
1164             }
1165           cnt = mpz_sizeinbase (r, 2);
1166           mpz_set_ui (r, cnt - 1);
1167       }
1168       return;
1169     case LOG:
1170       { unsigned long int cnt;
1171           mpz_init (lhs); mpz_init (rhs);
1172           mpz_eval_expr (lhs, e->operands.ops.lhs);
1173           mpz_eval_expr (rhs, e->operands.ops.rhs);
1174           if (mpz_sgn (lhs) <= 0)
1175             {
1176               error = "logarithm of non-positive number";
1177               mpz_clear (lhs); mpz_clear (rhs);
1178               longjmp (errjmpbuf, 1);
1179             }
1180           if (mpz_cmp_ui (rhs, 256) >= 0)
1181             {
1182               error = "logarithm base too large";
1183               mpz_clear (lhs); mpz_clear (rhs);
1184               longjmp (errjmpbuf, 1);
1185             }
1186           cnt = mpz_sizeinbase (lhs, mpz_get_ui (rhs));
1187           mpz_set_ui (r, cnt - 1);
1188           mpz_clear (lhs); mpz_clear (rhs);
1189       }
1190       return;
1191     case FERMAT:
1192       {
1193           unsigned long int t;
1194           mpz_init (lhs);
1195           mpz_eval_expr (lhs, e->operands.ops.lhs);
1196           t = (unsigned long int) 1 << mpz_get_ui (lhs);
1197           if (mpz_cmp_ui (lhs, ~(unsigned long int) 0) > 0 || t == 0)
1198             {
1199               error = "too large Mersenne number index";
1200               mpz_clear (lhs);
1201               longjmp (errjmpbuf, 1);
1202             }
1203           mpz_set_ui (r, 1);
1204           mpz_mul_2exp (r, r, t);
1205           mpz_add_ui (r, r, 1);
1206           mpz_clear (lhs);
1207       }
1208       return;
1209     case MERSENNE:
1210       mpz_init (lhs);
1211       mpz_eval_expr (lhs, e->operands.ops.lhs);
1212       if (mpz_cmp_ui (lhs, ~(unsigned long int) 0) > 0)
1213           {
1214             error = "too large Mersenne number index";
1215             mpz_clear (lhs);
1216             longjmp (errjmpbuf, 1);
1217           }
1218       mpz_set_ui (r, 1);
1219       mpz_mul_2exp (r, r, mpz_get_ui (lhs));
1220       mpz_sub_ui (r, r, 1);
1221       mpz_clear (lhs);
1222       return;
1223     case FIBONACCI:
1224       { mpz_t t;
1225           unsigned long int n, i;
1226           mpz_init (lhs);
1227           mpz_eval_expr (lhs, e->operands.ops.lhs);
1228           if (mpz_sgn (lhs) <= 0 || mpz_cmp_si (lhs, 1000000000) > 0)
1229             {
1230               error = "Fibonacci index out of range";
1231               mpz_clear (lhs);
1232               longjmp (errjmpbuf, 1);
1233             }
1234           n = mpz_get_ui (lhs);
1235           mpz_clear (lhs);
1236 
1237 #if __GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1
1238           mpz_fib_ui (r, n);
1239 #else
1240           mpz_init_set_ui (t, 1);
1241           mpz_set_ui (r, 1);
1242 
1243           if (n <= 2)
1244             mpz_set_ui (r, 1);
1245           else
1246             {
1247               for (i = 3; i <= n; i++)
1248                 {
1249                     mpz_add (t, t, r);
1250                     mpz_swap (t, r);
1251                 }
1252             }
1253           mpz_clear (t);
1254 #endif
1255       }
1256       return;
1257     case RANDOM:
1258       {
1259           unsigned long int n;
1260           mpz_init (lhs);
1261           mpz_eval_expr (lhs, e->operands.ops.lhs);
1262           if (mpz_sgn (lhs) <= 0 || mpz_cmp_si (lhs, 1000000000) > 0)
1263             {
1264               error = "random number size out of range";
1265               mpz_clear (lhs);
1266               longjmp (errjmpbuf, 1);
1267             }
1268           n = mpz_get_ui (lhs);
1269           mpz_clear (lhs);
1270           mpz_urandomb (r, rstate, n);
1271       }
1272       return;
1273     case NEXTPRIME:
1274       {
1275           mpz_eval_expr (r, e->operands.ops.lhs);
1276           mpz_nextprime (r, r);
1277       }
1278       return;
1279     case BINOM:
1280       mpz_init (lhs); mpz_init (rhs);
1281       mpz_eval_expr (lhs, e->operands.ops.lhs);
1282       mpz_eval_expr (rhs, e->operands.ops.rhs);
1283       {
1284           unsigned long int k;
1285           if (mpz_cmp_ui (rhs, ~(unsigned long int) 0) > 0)
1286             {
1287               error = "k too large in (n over k) expression";
1288               mpz_clear (lhs); mpz_clear (rhs);
1289               longjmp (errjmpbuf, 1);
1290             }
1291           k = mpz_get_ui (rhs);
1292           mpz_bin_ui (r, lhs, k);
1293       }
1294       mpz_clear (lhs); mpz_clear (rhs);
1295       return;
1296     case TIMING:
1297       {
1298           int t0;
1299           t0 = cputime ();
1300           mpz_eval_expr (r, e->operands.ops.lhs);
1301           printf ("time: %d\n", cputime () - t0);
1302       }
1303       return;
1304     default:
1305       abort ();
1306     }
1307 }
1308 
1309 /* Evaluate the expression E modulo MOD and put the result in R.  */
1310 void
mpz_eval_mod_expr(mpz_ptr r,expr_t e,mpz_ptr mod)1311 mpz_eval_mod_expr (mpz_ptr r, expr_t e, mpz_ptr mod)
1312 {
1313   mpz_t lhs, rhs;
1314 
1315   switch (e->op)
1316     {
1317       case POW:
1318           mpz_init (lhs); mpz_init (rhs);
1319           mpz_eval_mod_expr (lhs, e->operands.ops.lhs, mod);
1320           mpz_eval_expr (rhs, e->operands.ops.rhs);
1321           mpz_powm (r, lhs, rhs, mod);
1322           mpz_clear (lhs); mpz_clear (rhs);
1323           return;
1324       case PLUS:
1325           mpz_init (lhs); mpz_init (rhs);
1326           mpz_eval_mod_expr (lhs, e->operands.ops.lhs, mod);
1327           mpz_eval_mod_expr (rhs, e->operands.ops.rhs, mod);
1328           mpz_add (r, lhs, rhs);
1329           if (mpz_cmp_si (r, 0L) < 0)
1330             mpz_add (r, r, mod);
1331           else if (mpz_cmp (r, mod) >= 0)
1332             mpz_sub (r, r, mod);
1333           mpz_clear (lhs); mpz_clear (rhs);
1334           return;
1335       case MINUS:
1336           mpz_init (lhs); mpz_init (rhs);
1337           mpz_eval_mod_expr (lhs, e->operands.ops.lhs, mod);
1338           mpz_eval_mod_expr (rhs, e->operands.ops.rhs, mod);
1339           mpz_sub (r, lhs, rhs);
1340           if (mpz_cmp_si (r, 0L) < 0)
1341             mpz_add (r, r, mod);
1342           else if (mpz_cmp (r, mod) >= 0)
1343             mpz_sub (r, r, mod);
1344           mpz_clear (lhs); mpz_clear (rhs);
1345           return;
1346       case MULT:
1347           mpz_init (lhs); mpz_init (rhs);
1348           mpz_eval_mod_expr (lhs, e->operands.ops.lhs, mod);
1349           mpz_eval_mod_expr (rhs, e->operands.ops.rhs, mod);
1350           mpz_mul (r, lhs, rhs);
1351           mpz_mod (r, r, mod);
1352           mpz_clear (lhs); mpz_clear (rhs);
1353           return;
1354       default:
1355           mpz_init (lhs);
1356           mpz_eval_expr (lhs, e);
1357           mpz_mod (r, lhs, mod);
1358           mpz_clear (lhs);
1359           return;
1360     }
1361 }
1362 
1363 void
cleanup_and_exit(int sig)1364 cleanup_and_exit (int sig)
1365 {
1366   switch (sig) {
1367 #ifdef LIMIT_RESOURCE_USAGE
1368   case SIGXCPU:
1369     printf ("expression took too long to evaluate%s\n", newline);
1370     break;
1371 #endif
1372   case SIGFPE:
1373     printf ("divide by zero%s\n", newline);
1374     break;
1375   default:
1376     printf ("expression required too much memory to evaluate%s\n", newline);
1377     break;
1378   }
1379   exit (-2);
1380 }
1381