การเรียนรู้ของเครื่อง - การถดถอยเชิงเส้น
การถดถอย
คำว่าการถดถอยจะใช้เมื่อคุณพยายามค้นหาความสัมพันธ์ระหว่างตัวแปร
ในแมชชีนเลิร์นนิงและในการสร้างแบบจำลองทางสถิติ ความสัมพันธ์นั้นถูกใช้เพื่อทำนายผลลัพธ์ของเหตุการณ์ในอนาคต
การถดถอยเชิงเส้น
การถดถอยเชิงเส้นใช้ความสัมพันธ์ระหว่างจุดข้อมูลเพื่อวาดเส้นตรงผ่านจุดข้อมูลทั้งหมด
บรรทัดนี้สามารถใช้ทำนายค่าในอนาคตได้
ในแมชชีนเลิร์นนิง การทำนายอนาคตมีความสำคัญมาก
มันทำงานอย่างไร?
Python มีวิธีการค้นหาความสัมพันธ์ระหว่างจุดข้อมูลและวาดเส้นการถดถอยเชิงเส้น เราจะแสดงวิธีใช้วิธีการเหล่านี้แทนการใช้สูตรทางคณิตศาสตร์
ในตัวอย่างด้านล่าง แกน x แสดงถึงอายุ และแกน y แสดงถึงความเร็ว เราได้จดทะเบียนรถอายุและความเร็วไว้ 13 คัน ขณะที่กำลังผ่านด่านเก็บค่าผ่านทาง ให้เราดูว่าข้อมูลที่รวบรวมสามารถนำมาใช้ในการถดถอยเชิงเส้นได้หรือไม่:
ตัวอย่าง
เริ่มต้นด้วยการวาดแผนภาพกระจาย:
import matplotlib.pyplot as plt
x = [5,7,8,7,2,17,2,9,4,11,12,9,6]
y =
[99,86,87,88,111,86,103,87,94,78,77,85,86]
plt.scatter(x, y)
plt.show()
ผลลัพธ์:
ตัวอย่าง
นำเข้าscipy
และวาดเส้นของการถดถอยเชิงเส้น:
import matplotlib.pyplot as plt
from scipy import stats
x = [5,7,8,7,2,17,2,9,4,11,12,9,6]
y =
[99,86,87,88,111,86,103,87,94,78,77,85,86]
slope, intercept, r,
p, std_err = stats.linregress(x, y)
def myfunc(x):
return slope * x + intercept
mymodel = list(map(myfunc, x))
plt.scatter(x, y)
plt.plot(x, mymodel)
plt.show()
ผลลัพธ์:
ตัวอย่างที่อธิบาย
นำเข้าโมดูลที่คุณต้องการ
คุณสามารถเรียนรู้เกี่ยวกับโมดูล Matplotlib ได้ใน บทช่วย สอน Matplotlibของ เรา
คุณสามารถเรียนรู้เกี่ยวกับโมดูล SciPy ได้ใน บทช่วย สอน SciPyของ เรา
import matplotlib.pyplot as plt
from scipy
import stats
สร้างอาร์เรย์ที่แสดงค่าของแกน x และ y:
x = [5,7,8,7,2,17,2,9,4,11,12,9,6]
y = [99,86,87,88,111,86,103,87,94,78,77,85,86]
ดำเนินการวิธีการที่คืนค่าคีย์ที่สำคัญของการถดถอยเชิงเส้น:
slope, intercept, r,
p, std_err = stats.linregress(x, y)
สร้างฟังก์ชันที่ใช้ ค่า slope
และ
intercept
เพื่อส่งกลับค่าใหม่ ค่าใหม่นี้แสดงถึงตำแหน่งที่จะวางค่า x ที่สอดคล้องกันบนแกน y:
def myfunc(x):
return slope * x + intercept
เรียกใช้แต่ละค่าของอาร์เรย์ x ผ่านฟังก์ชัน ซึ่งจะส่งผลให้อาร์เรย์ใหม่มีค่าใหม่สำหรับแกน y:
mymodel = list(map(myfunc, x))
วาดพล็อตกระจายดั้งเดิม:
plt.scatter(x, y)
ลากเส้นการถดถอยเชิงเส้น:
plt.plot(x, mymodel)
แสดงไดอะแกรม:
plt.show()
R สำหรับความสัมพันธ์
สิ่งสำคัญคือต้องรู้ว่าความสัมพันธ์ระหว่างค่าของแกน x กับค่าของแกน y เป็นอย่างไร ถ้าไม่มีความสัมพันธ์ การถดถอยเชิงเส้นจะไม่สามารถใช้ทำนายอะไรได้
ความสัมพันธ์นี้เรียกว่าสัมประสิทธิ์สห
r
สัมพันธ์
ค่าr
มีตั้งแต่ -1 ถึง 1 โดยที่ 0 หมายถึงไม่มีความสัมพันธ์ และ 1 (และ -1) หมายถึงเกี่ยวข้องกัน 100%
Python และโมดูล Scipy จะคำนวณค่านี้ให้คุณ สิ่งที่คุณต้องทำคือป้อนค่าด้วยค่า x และ y
ตัวอย่าง
ข้อมูลของฉันเหมาะสมกับการถดถอยเชิงเส้นแค่ไหน?
from scipy import stats
x =
[5,7,8,7,2,17,2,9,4,11,12,9,6]
y =
[99,86,87,88,111,86,103,87,94,78,77,85,86]
slope, intercept, r,
p, std_err = stats.linregress(x, y)
print(r)
หมายเหตุ:ผลลัพธ์ -0.76 แสดงว่ามีความสัมพันธ์ไม่สมบูรณ์แบบ แต่บ่งชี้ว่าเราสามารถใช้การถดถอยเชิงเส้นในการคาดคะเนในอนาคตได้
ทำนายค่าในอนาคต
ตอนนี้เราสามารถใช้ข้อมูลที่รวบรวมมาเพื่อทำนายค่าในอนาคตได้
ตัวอย่าง ให้เราลองทำนายความเร็วของรถอายุ 10 ปี
ในการทำเช่นนั้น เราต้องการmyfunc()
ฟังก์ชันเดียวกันจากตัวอย่างด้านบน:
def myfunc(x):
return slope * x + intercept
ตัวอย่าง
ทำนายความเร็วของรถอายุ 10 ปี:
from scipy import stats
x = [5,7,8,7,2,17,2,9,4,11,12,9,6]
y =
[99,86,87,88,111,86,103,87,94,78,77,85,86]
slope, intercept, r,
p, std_err = stats.linregress(x, y)
def myfunc(x):
return slope * x + intercept
speed = myfunc(10)
print(speed)
ตัวอย่างทำนายความเร็วที่ 85.6 ซึ่งเราสามารถอ่านได้จากแผนภาพ:
หุ่นไม่ดี?
ให้เราสร้างตัวอย่างที่การถดถอยเชิงเส้นไม่ใช่วิธีที่ดีที่สุดในการทำนายค่าในอนาคต
ตัวอย่าง
ค่าเหล่านี้สำหรับแกน x และ y ควรส่งผลให้การถดถอยเชิงเส้นไม่เหมาะสมอย่างยิ่ง:
import matplotlib.pyplot as plt
from scipy import stats
x = [89,43,36,36,95,10,66,34,38,20,26,29,48,64,6,5,36,66,72,40]
y =
[21,46,3,35,67,95,53,72,58,10,26,34,90,33,38,20,56,2,47,15]
slope,
intercept, r, p, std_err = stats.linregress(x, y)
def
myfunc(x):
return slope * x + intercept
mymodel = list(map(myfunc,
x))
plt.scatter(x, y)
plt.plot(x, mymodel)
plt.show()
ผลลัพธ์:
และr
สำหรับความสัมพันธ์?
ตัวอย่าง
คุณควรได้รับr
ค่า ที่ต่ำมาก
import numpy
from scipy import stats
x =
[89,43,36,36,95,10,66,34,38,20,26,29,48,64,6,5,36,66,72,40]
y =
[21,46,3,35,67,95,53,72,58,10,26,34,90,33,38,20,56,2,47,15]
slope, intercept, r,
p, std_err = stats.linregress(x, y)
print(r)
ผลลัพธ์: 0.013 บ่งชี้ถึงความสัมพันธ์ที่แย่มาก และบอกเราว่าชุดข้อมูลนี้ไม่เหมาะสำหรับการถดถอยเชิงเส้น