import sys
import numpy as np
from pyspark.sql import SparkSession
D = 10  # Number of dimensions


def readPointBatch(iterator):
    """Parse one partition of text lines into a single NumPy matrix.

    The data file contains lines of the form ``<label> <x1> <x2> ... <xD>``;
    values may be separated by whitespace and/or commas. We operate on a whole
    batch (partition) at once so the gradient computation can work on one
    dense matrix instead of one point at a time; ``gradient()`` later splits
    column 0 (labels) from the remaining columns (coordinates).

    Args:
        iterator: Iterable of text lines belonging to one partition.

    Returns:
        A one-element list holding a float64 array of shape
        ``(num_non_blank_lines, D + 1)``. Blank lines are skipped.

    Raises:
        ValueError: If a token is not a number, or a line does not contain
            exactly ``D + 1`` values.
    """
    # Skip blank / whitespace-only lines so a stray empty line does not
    # abort the whole job with a broadcast error.
    lines = [line for line in iterator if line.strip()]
    matrix = np.zeros((len(lines), D + 1))
    for i, line in enumerate(lines):
        # Parse at full float64 precision: parsing as float32 would silently
        # round every value before it is stored in the float64 matrix.
        values = [float(token) for token in line.replace(',', ' ').split()]
        if len(values) != D + 1:
            raise ValueError(
                "expected %d values per line, got %d: %r"
                % (D + 1, len(values), line)
            )
        matrix[i] = values
    return [matrix]
def gradient(matrix, w):
    """Return the logistic-loss gradient summed over every row of *matrix*.

    Column 0 of *matrix* holds the point labels (first column of the input
    file) and the remaining columns hold the point coordinates. For each
    point ``(x, y)`` the per-row term is the gradient of
    ``log(1 + exp(-y * w.x))``, a loss form that presumes labels encoded as
    -1/+1.

    Args:
        matrix: 2-D array, one point per row.
        w: Current weight vector, one entry per coordinate column.

    Returns:
        1-D array with the same length as *w*; an empty *matrix* yields zeros.
    """
    Y = matrix[:, 0]  # point labels
    X = matrix[:, 1:]  # point coordinates
    # Broadcasting scales column j of X.T (point j) by its coefficient, then
    # sum(1) adds up the per-point gradients.
    return ((1.0 / (1.0 + np.exp(-Y * X.dot(w))) - 1.0) * Y * X.T).sum(1)


def add(x, y):
    """Accumulate *y* into *x* in place and return *x* (reduce combiner)."""
    x += y
    return x


if __name__ == "__main__":
    if len(sys.argv) != 3:
        print("Usage: logistic_regression <file> <iterations>", file=sys.stderr)
        # sys.exit is always available; the bare exit() builtin is injected
        # by the site module and may be missing (e.g. under ``python -S``).
        sys.exit(-1)

    print("""WARN: This is a naive implementation of Logistic Regression and is
given as an example!
Please refer to examples/src/main/python/ml/logistic_regression_with_elastic_net.py
to see how ML's implementation is used.""", file=sys.stderr)

    spark = (
        SparkSession
        .builder
        .appName("PythonLR")
        .getOrCreate()
    )

    # mapPartitions hands each partition's lines to readPointBatch, so every
    # RDD element is one NumPy matrix; cache() avoids re-reading and
    # re-parsing the input file on every iteration.
    points = (
        spark.read.text(sys.argv[1]).rdd.map(lambda r: r[0])
        .mapPartitions(readPointBatch)
        .cache()
    )
    iterations = int(sys.argv[2])

    # Initialize w to a random value in [-1, 1). random_sample is the
    # canonical name of the legacy ``ranf`` alias.
    w = 2 * np.random.random_sample(size=D) - 1
    print("Initial w: " + str(w))

    for i in range(iterations):
        print("On iteration %i" % (i + 1))
        # The lambda captures the current w when the job is submitted; the
        # summed gradient is then applied to w on the driver.
        w -= points.map(lambda m: gradient(m, w)).reduce(add)

    print("Final w: " + str(w))
    spark.stop()
1. 其中 map() 与 mapPartitions() 函数的区别：
map() 函数作用于数据集的每一个元素，所以它的处理对象是单个元素；
mapPartitions() 函数作用于数据集的每个分区，所以它的处理对象是一个迭代器（包含该分区的多个元素）。本例正是借此让 readPointBatch() 把一个分区的所有行一次性读入同一个 NumPy 矩阵。
2. enumerate() 函数
返回一个枚举对象，迭代时依次产生 (索引, 元素) 元组。
以下展示了使用 enumerate() 方法的实例：
>>> seasons = ['Spring', 'Summer', 'Fall', 'Winter']
>>> list(enumerate(seasons))
[(0, 'Spring'), (1, 'Summer'), (2, 'Fall'), (3, 'Winter')]
>>> list(enumerate(seasons, start=1))  # 下标从 1 开始
[(1, 'Spring'), (2, 'Summer'), (3, 'Fall'), (4, 'Winter')]
ps：readPointBatch() 函数写得挺优雅的。
3. numpy.fromstring()：按分隔符 sep 把字符串解析成一维数组
Examples
>>> np.fromstring('1 2', dtype=int, sep=' ')
array([1, 2])
>>> np.fromstring('1, 2', dtype=int, sep=',')
array([1, 2])