commit 2e9d87b259dcacf7c74ce8ca85554b05d08533d0
parent 0cc78918862ea2153868166130c2c4bf51ee5ae4
Author: Anders Damsgaard <anders@adamsgaard.dk>
Date:   Wed, 24 Nov 2021 15:38:45 +0100
linregress: add option to save plot to disk
Diffstat:
1 file changed, 20 insertions(+), 4 deletions(-)
diff --git a/linregress b/linregress
@@ -3,8 +3,8 @@ import sys, getopt
 import numpy
 import scipy.stats
 
-def usage(arg0):
-	print("usage: {} [-h] [-s significance_level]".format(sys.argv[0]))
+def usage():
+	print("usage: {} [-h] [-o outfile] [-s significance_level]".format(sys.argv[0]))
 
 def read2dstdin():
 	input = numpy.loadtxt(sys.stdin)
@@ -25,7 +25,6 @@ def reportsignif(p, significlvl):
 		print("greater or equal than the significance level ({:g}),".format(significlvl))
 		print("which means that the null hypothesis of zero slope CANNOT be rejected.\n")
 
-
 def reportcorr(r):
 	print("The correlation coefficient (r-value) denotes ", end="")
 	if (abs(r) < 0.01):
@@ -43,10 +42,23 @@ def reportcorr(r):
 		print(" positive" if r > 0.0 else " negative")
 		print(" relationship.")
 
+def plotresult(x, y, res, outfile):
+	import matplotlib.pyplot as plt
+	plt.figure()
+	plt.scatter(x, y, label="data")
+	x_ = numpy.linspace(min(x), max(x))
+	y_fit = res.slope * x_ + res.intercept
+	plt.title("p-value: {:.3g}, corr. coeff.: {:.3g}".format(res.pvalue, res.rvalue))
+	plt.plot(x_, y_fit, "-k", label="slope: {:.3g}, intercept: {:.3g}".format(res.slope, res.intercept))
+	plt.legend()
+	plt.tight_layout()
+	plt.savefig(outfile)
+
 def main(argv):
 	significlvl = 0.05
+	outfile = ''
 	try:
-		opts, args = getopt.getopt(argv, "hs:")
+		opts, args = getopt.getopt(argv, "ho:s:")
 	except getopt.GetoptError:
 		usage()
 		sys.exit(2)
@@ -56,12 +68,16 @@ def main(argv):
 			sys.exit(0)
 		elif opt == "-s":
 			significlvl = float(arg)
+		elif opt == "-o":
+			outfile = arg
 
 	x, y = read2dstdin()
 	res = scipy.stats.linregress(x, y, alternative="two-sided")
 	reportregress(res)
 	reportsignif(res.pvalue, significlvl)
 	reportcorr(res.rvalue)
+	if outfile:
+		plotresult(x, y, res, outfile)
 
 if __name__ == "__main__":
 	main(sys.argv[1:])