mpmath_special_functions_test_generator.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. #!/usr/bin/env python3
  2. import mpmath as mp
  3. import mpmath_riccati_bessel as mrb
  4. import mpmath_input_arguments as mia
  5. import os.path
  6. class TestData:
  7. def __init__(self, list_to_parse, filetype):
  8. self.filetype = filetype
  9. if self.filetype == 'c++':
  10. self.cpp_parse(list_to_parse)
  11. else:
  12. raise NotImplementedError("Only C++ files *.hpp parsing was implemented")
  13. def cpp_parse(self, list_to_parse):
  14. self.comment = list_to_parse[0]
  15. if self.comment[:2] != '//': raise ValueError('Not a comment')
  16. self.typeline = list_to_parse[1]
  17. if 'std::vector' not in self.typeline: raise ValueError('Unexpected C++ container')
  18. self.testname = list_to_parse[2]
  19. self.opening = list_to_parse[3]
  20. if self.opening != '= {': raise ValueError('For C++ we expect opeing with = {');
  21. self.ending = list_to_parse[-1]
  22. if self.ending != '};': raise ValueError('For C++ we expect closing };')
  23. self.evaluated_data = list_to_parse[4:-1]
  24. def get_string(self):
  25. out_sting = self.comment + '\n' + self.typeline + '\n' + self.testname + '\n' + self.opening + '\n'
  26. for result in self.evaluated_data:
  27. out_sting += result + '\n'
  28. out_sting += self.ending + '\n'
  29. return out_sting
  30. class UpdateSpecialFunctionsEvaluations:
  31. def __init__(self, filename='default_out.hpp', complex_arguments=[],
  32. output_dps=16, max_num_elements_of_nlist=51):
  33. self.evaluated_data = []
  34. self.test_setup = []
  35. self.filename = filename
  36. self.read_evaluated_data()
  37. self.complex_arguments = complex_arguments
  38. self.output_dps = output_dps
  39. self.max_num_elements_of_nlist = max_num_elements_of_nlist
  40. def read_evaluated_data(self):
  41. self.filetype = 'undefined'
  42. if self.filename.endswith('.hpp'):
  43. self.filetype = 'c++'
  44. if self.filename.endswith('.f90'):
  45. self.filetype = 'fortran'
  46. if not os.path.exists(self.filename):
  47. print("WARNING! Found no data file:", self.filename)
  48. return
  49. with open(self.filename, 'r') as in_file:
  50. content = in_file.readlines()
  51. content = [x.strip() for x in content]
  52. while '' in content:
  53. record_end_index = content.index('')
  54. new_record = content[:record_end_index]
  55. content = content[record_end_index + 1:]
  56. self.add_record(new_record)
  57. self.add_record(content)
  58. def add_record(self, new_record):
  59. if len(new_record) == 0: return
  60. if len(new_record) < 6: raise ValueError('Not enough lines in record:', new_record)
  61. self.evaluated_data.append(TestData(new_record, self.filetype))
  62. def get_file_content(self):
  63. self.evaluated_data.sort(key=lambda x: x.testname) # , reverse=True)
  64. out_string = ''
  65. for record in self.evaluated_data:
  66. out_string += record.get_string() + '\n'
  67. return out_string[:-1]
  68. def remove(self, testname):
  69. for i, result in enumerate(self.evaluated_data):
  70. if result.testname == testname:
  71. del self.evaluated_data[i]
  72. def get_n_list(self, z, max_number_of_elements=10):
  73. nmax = mrb.LeRu_cutoff(z)
  74. factor = nmax ** (1 / (max_number_of_elements - 2))
  75. n_list = [int(factor ** i) for i in range(max_number_of_elements - 1)]
  76. n_list.append(0)
  77. n_set = set(n_list)
  78. return sorted(n_set)
  79. def compose_result_string(self, mpf_x, mpf_m, n, mpf_value, output_dps):
  80. return ('{'+
  81. mp.nstr(mpf_x, output_dps * 2) + ',{' +
  82. mp.nstr(mpf_m.real, output_dps * 2) + ',' +
  83. mp.nstr(mpf_m.imag, output_dps * 2) + '},' +
  84. str(n) + ',{' +
  85. mp.nstr(mpf_value.real, output_dps) + ',' +
  86. mp.nstr(mpf_value.imag, output_dps) + '},' +
  87. mp.nstr(mp.fabs(mpf_value.real * 10 ** -output_dps), 2) + ',' +
  88. mp.nstr(mp.fabs(mpf_value.imag * 10 ** -output_dps), 2) +
  89. '},')
  90. def get_test_data_nlist(self, z_record, output_dps, n, func):
  91. isNeedMoreDPS = False
  92. x = str(z_record[0])
  93. mr = str(z_record[1][0])
  94. mi = str(z_record[1][1])
  95. z_str = ''
  96. try:
  97. mpf_x = mp.mpf(x)
  98. mpf_m = mp.mpc(mr, mi)
  99. z = mpf_x*mpf_m
  100. if self.is_only_x: z = mp.mpf(x)
  101. if self.is_xm:
  102. mpf_value = func(n, mpf_x, mpf_m)
  103. else:
  104. mpf_value = func(n, z)
  105. z_str = self.compose_result_string(mpf_x, mpf_m, n, mpf_value, output_dps)
  106. if mp.nstr(mpf_value.real, output_dps) == '0.0' \
  107. or mp.nstr(mpf_value.imag, output_dps) == '0.0':
  108. isNeedMoreDPS = True
  109. except:
  110. isNeedMoreDPS = True
  111. return z_str, isNeedMoreDPS
  112. def get_test_data(self, Du_test, output_dps, max_num_elements_of_n_list, func, funcname):
  113. output_list = ['// x, complex(m), n, complex(f(n,z)), abs_err_real, abs_err_imag',
  114. 'std::vector< std::tuple< nmie::FloatType, std::complex<nmie::FloatType>, int, std::complex<nmie::FloatType>, nmie::FloatType, nmie::FloatType > >',
  115. str(funcname) + '_test_' + str(output_dps) + 'digits', '= {']
  116. for z_record in Du_test:
  117. x = str(z_record[0])
  118. mr = str(z_record[1][0])
  119. mi = str(z_record[1][1])
  120. mp.mp.dps = 20
  121. z = mp.mpf(x) * mp.mpc(mr, mi)
  122. n_list = self.get_n_list(z, max_num_elements_of_n_list)
  123. if z_record[4] == 'Yang': n_list = [0, 1, 30, 50, 60, 70, 75, 80, 85, 90, 99, 116, 130]
  124. print(z, n_list)
  125. failed_evaluations = 0
  126. for n in n_list:
  127. mp.mp.dps = output_dps
  128. old_z_string, isNeedMoreDPS = self.get_test_data_nlist(z_record, output_dps, n, func, )
  129. mp.mp.dps = int(output_dps*1.41)
  130. new_z_string, isNeedMoreDPS = self.get_test_data_nlist(z_record, output_dps, n, func)
  131. while old_z_string != new_z_string \
  132. or isNeedMoreDPS:
  133. new_dps = int(mp.mp.dps * 1.41)
  134. if new_dps > 300: break
  135. mp.mp.dps = new_dps
  136. print("New dps = ", mp.mp.dps, 'n =', n, ' (max ', n_list[-1], ') for z =', z, ' ', end='')
  137. old_z_string = new_z_string
  138. new_z_string, isNeedMoreDPS = self.get_test_data_nlist(z_record, output_dps, n, func)
  139. if new_z_string != '':
  140. output_list.append(new_z_string)
  141. else:
  142. failed_evaluations += 1
  143. # break
  144. result_str = "All done!"
  145. if failed_evaluations > 0: result_str = " FAILED!"
  146. print("\n", result_str, "Failed evaluations ", failed_evaluations, ' of ', len(n_list))
  147. output_list.append('};')
  148. return output_list
  149. def run_test(self, func, funcname, is_only_x=False, is_xm=False):
  150. self.is_only_x = is_only_x
  151. self.is_xm = is_xm
  152. self.remove_argument_duplicates()
  153. out_list_result = self.get_test_data(self.complex_arguments, self.output_dps,
  154. self.max_num_elements_of_nlist,
  155. func, funcname)
  156. testname = str(funcname) + '_test_' + str(self.output_dps) + 'digits'
  157. self.remove(testname)
  158. self.add_record(out_list_result)
  159. def remove_argument_duplicates(self):
  160. print("Arguments in input: ", len(self.complex_arguments))
  161. mp.mp.dps = 20
  162. self.complex_arguments.sort()
  163. filtered_list = []
  164. filtered_list.append(self.complex_arguments[0])
  165. for i in range(1, len(self.complex_arguments)):
  166. # if x and m are the same: continue
  167. if (filtered_list[-1][0] == self.complex_arguments[i][0] and
  168. filtered_list[-1][1] == self.complex_arguments[i][1]):
  169. continue
  170. # argument list is sorted, so when only x is needed
  171. # keep the record with the largest m
  172. if (self.is_only_x
  173. and filtered_list[-1][0] == self.complex_arguments[i][0]):
  174. # continue
  175. del filtered_list[-1]
  176. filtered_list.append(self.complex_arguments[i])
  177. self.complex_arguments = filtered_list
  178. # print(self.complex_arguments)
  179. print("Arguments after filtering: ", len(self.complex_arguments))
  180. # exit(0)
  181. def main():
  182. sf_evals = UpdateSpecialFunctionsEvaluations(filename='test_spec_functions_data.hpp',
  183. complex_arguments=mia.complex_arguments,
  184. output_dps=30, max_num_elements_of_nlist=51)
  185. # output_dps=7, max_num_elements_of_nlist=51)
  186. # output_dps=5, max_num_elements_of_nlist=3)
  187. # sf_evals.run_test(mrb.D1, 'D1')
  188. # sf_evals.run_test(mrb.D2, 'D2')
  189. # sf_evals.run_test(mrb.D3, 'D3')
  190. # sf_evals.run_test(mrb.psi, 'psi', is_only_x=True)
  191. # sf_evals.run_test(mrb.xi, 'xi', is_only_x=True)
  192. # # In literature Zeta or Ksi denote the Riccati-Bessel function of third kind.
  193. # sf_evals.run_test(mrb.ksi, 'zeta', is_only_x=True)
  194. # sf_evals.run_test(mrb.an, 'an', is_xm=True)
  195. # sf_evals.run_test(mrb.bn, 'bn', is_xm=True)
  196. # sf_evals.run_test(mrb.psi, 'psi')
  197. # sf_evals.run_test(mrb.psi_div_ksi, 'psi_div_ksi')
  198. # sf_evals.run_test(mrb.psi_mul_ksi, 'psi_mul_zeta', is_only_x=True)
  199. # sf_evals.run_test(mrb.psi_div_xi, 'psi_div_xi')
  200. with open(sf_evals.filename, 'w') as out_file:
  201. out_file.write(sf_evals.get_file_content())
  202. for record in mia.complex_arguments:
  203. mp.mp.dps = 20
  204. output_dps = 16
  205. x = mp.mpf(str(record[0]))
  206. mr = str(record[1][0])
  207. mi = str(record[1][1])
  208. m = mp.mpc(mr, mi)
  209. Qext_ref = record[2]
  210. Qsca_ref = record[3]
  211. test_case = record[4]
  212. nmax = int(x + 4.05*x**(1./3.) + 2)+2+28
  213. print(f"\n ===== test case: {test_case} =====", flush=True)
  214. print(f"x={x}, m={m}, N={nmax} \nQsca_ref = {Qsca_ref} \tQext_ref = {Qext_ref}", flush=True)
  215. Qext_mp = mrb.Qext(x,m,nmax, output_dps)
  216. Qsca_mp = mrb.Qsca(x,m,nmax, output_dps)
  217. print(f"Qsca_mp = {mp.nstr(Qsca_mp[-1],output_dps)} \tQext_mp = {mp.nstr(Qext_mp[-1],output_dps)}", flush=True)
  218. print(mp.nstr(Qsca_mp,output_dps))
  219. print(mp.nstr(Qext_mp,output_dps))
  220. # n=1
  221. # print(f'n={n}, x={x}, m={m}\nbn[{n}]={mp.nstr(mrb.bn(n,x,m), output_dps)}')
  222. main()