commit 2e9d87b259dcacf7c74ce8ca85554b05d08533d0
parent 0cc78918862ea2153868166130c2c4bf51ee5ae4
Author: Anders Damsgaard <anders@adamsgaard.dk>
Date: Wed, 24 Nov 2021 15:38:45 +0100
linregress: add option to save plot to disk
Diffstat:
1 file changed, 20 insertions(+), 4 deletions(-)
diff --git a/linregress b/linregress
@@ -3,8 +3,8 @@ import sys, getopt
import numpy
import scipy.stats
-def usage(arg0):
- print("usage: {} [-h] [-s significance_level]".format(sys.argv[0]))
+def usage():
+ print("usage: {} [-h] [-o outfile] [-s significance_level]".format(sys.argv[0]))
def read2dstdin():
input = numpy.loadtxt(sys.stdin)
@@ -25,7 +25,6 @@ def reportsignif(p, significlvl):
print("greater or equal than the significance level ({:g}),".format(significlvl))
print("which means that the null hypothesis of zero slope CANNOT be rejected.\n")
-
def reportcorr(r):
print("The correlation coefficient (r-value) denotes ", end="")
if (abs(r) < 0.01):
@@ -43,10 +42,23 @@ def reportcorr(r):
print(" positive" if r > 0.0 else " negative")
print(" relationship.")
+def plotresult(x, y, res, outfile):
+ import matplotlib.pyplot as plt
+ plt.figure()
+ plt.scatter(x, y, label="data")
+ x_ = numpy.linspace(min(x), max(x))
+ y_fit = res.slope * x_ + res.intercept
+ plt.title("p-value: {:.3g}, corr. coeff.: {:.3g}".format(res.pvalue, res.rvalue))
+ plt.plot(x_, y_fit, "-k", label="slope: {:.3g}, intercept: {:.3g}".format(res.slope, res.intercept))
+ plt.legend()
+ plt.tight_layout()
+ plt.savefig(outfile)
+
def main(argv):
significlvl = 0.05
+ outfile = ''
try:
- opts, args = getopt.getopt(argv, "hs:")
+ opts, args = getopt.getopt(argv, "ho:s:")
except getopt.GetoptError:
usage()
sys.exit(2)
@@ -56,12 +68,16 @@ def main(argv):
sys.exit(0)
elif opt == "-s":
significlvl = float(arg)
+ elif opt == "-o":
+ outfile = arg
x, y = read2dstdin()
res = scipy.stats.linregress(x, y, alternative="two-sided")
reportregress(res)
reportsignif(res.pvalue, significlvl)
reportcorr(res.rvalue)
+ if outfile:
+ plotresult(x, y, res, outfile)
if __name__ == "__main__":
main(sys.argv[1:])