Only the core code is listed:
1.plotData.m
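% Indices of positive (y = 1) and negative (y = 0) examples
ind1 = find(y == 1); ind0 = find(y == 0);
% Positives as black '+', negatives as black circles filled yellow (needs hold on to share the axes)
plot(X(ind1, 1), X(ind1, 2), 'k+', 'LineWidth', 2, 'MarkerSize', 7);
plot(X(ind0, 1), X(ind0, 2), 'ko', 'MarkerFaceColor', 'y', 'MarkerSize', 7);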
2.sigmoid.m
% Elementwise logistic function: works for scalars, vectors and matrices
g = 1 ./ (ones(size(z)) + exp(-z));
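Because MATLAB/Octave expands a scalar against an array in elementwise operations, the `ones(size(z))` term is not strictly needed: `1 ./ (1 + exp(-z))` gives the same result for any shape of `z`. The sketch below is a minimal illustration under that assumption; the name `sigmoid_fn` is hypothetical (chosen so it does not shadow the exercise's sigmoid.m), and it also checks the symmetry g(z) + g(-z) = 1.

```octave
% Hypothetical anonymous-function version of sigmoid.m (elementwise)
sigmoid_fn = @(z) 1 ./ (1 + exp(-z));

z = [-10 0 10; -1 1 2];          % any shape works
g = sigmoid_fn(z);

disp(sigmoid_fn(0));             % exp(0) = 1, so this is exactly 0.5
% Logistic symmetry: g(z) + g(-z) = 1, so this maximum deviation should be ~0
disp(max(max(abs(g + sigmoid_fn(-z) - 1))));
```

The scalar `1` is broadcast over the array, so the output `g` always has the same size as `z`.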
3.costFunction.m
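% m = number of training examples (length(y)), set earlier in the exercise template
h = sigmoid(X * theta); % h_theta(X) : m*1
% Cross-entropy cost, vectorized over all m examples
J = (-log(h.') * y - log(ones(1, m) - h.') * (ones(m, 1) - y)) / m;
% Gradient: same size as theta
grad = (X.' * (h - y)) / m;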
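The `J` line is the vectorized form of $J(\theta) = -\frac{1}{m}\left[y^T \log h + (1-y)^T \log(1-h)\right]$, and `grad` is $\frac{1}{m} X^T (h - y)$. A minimal sketch for sanity-checking both on a small made-up dataset: with `theta = 0` every prediction is 0.5, so the cost should equal log(2), and the analytic gradient should agree with a central finite-difference estimate. The data, `cost`, `gradf` and `step` are all hypothetical names introduced for illustration.

```octave
% Toy data (made up): intercept column + 2 features, m = 4 examples
X = [1 0.5 1.2; 1 -1.0 0.3; 1 2.0 -0.7; 1 0.1 0.9];
y = [1; 0; 1; 0];
m = length(y);

sigmoid_fn = @(z) 1 ./ (1 + exp(-z));
% Same cost as costFunction.m, written with y' instead of h.'
cost = @(t) (-y' * log(sigmoid_fn(X * t)) - (1 - y)' * log(1 - sigmoid_fn(X * t))) / m;
% Analytic gradient, same formula as grad in costFunction.m
gradf = @(t) X' * (sigmoid_fn(X * t) - y) / m;

% theta = 0 gives h = 0.5 everywhere, so J = -log(0.5) = log(2)
J0 = cost(zeros(3, 1));

% Central-difference check of the gradient at an arbitrary theta
theta = [0.2; -0.4; 0.3];
step = 1e-5;
numgrad = zeros(size(theta));
for k = 1:numel(theta)
  dt = zeros(size(theta));
  dt(k) = step;
  numgrad(k) = (cost(theta + dt) - cost(theta - dt)) / (2 * step);
end
fprintf('J(0) = %.4f, max gradient error = %.2e\n', J0, max(abs(numgrad - gradf(theta))));
```

A gradient error many orders of magnitude below the gradient itself suggests the analytic formula is implemented correctly.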
4.predict.m
h = sigmoid(X * theta);
p = (h >= 0.5);   % predict 1 when h_theta(x) >= 0.5, else 0 (logical m x 1 vector)
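Since the sigmoid is monotonically increasing and sigmoid(0) = 0.5, the test `h >= 0.5` is equivalent to `X * theta >= 0`, so the prediction does not actually need the sigmoid. A minimal sketch on made-up data with a hypothetical `theta`, which also computes training accuracy as the percentage of predictions that match `y`:

```octave
% Toy data (made up) and a hypothetical parameter vector
X = [1 0.5 1.2; 1 -1.0 0.3; 1 2.0 -0.7; 1 0.1 0.9];
y = [1; 0; 1; 0];
theta = [-0.5; 1.0; 0.2];

sigmoid_fn = @(z) 1 ./ (1 + exp(-z));
p = sigmoid_fn(X * theta) >= 0.5;   % logical vector, as in predict.m
p_lin = (X * theta) >= 0;           % same decision without calling the sigmoid

accuracy = mean(double(p == y)) * 100;   % percentage of correct predictions
fprintf('accuracy = %.1f%%\n', accuracy);
```

Here `X * theta` is `[0.24; -1.44; 1.36; -0.22]`, so all four toy examples are classified correctly.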
5.costFunctionReg.m
h = sigmoid(X * theta); % h_theta(X) : m*1
% Cost function: theta(1), the intercept term, is not regularized
J = (-log(h.')*y - log(ones(1, m) - h.')*(ones(m, 1) - y)) / m ...
    + (lambda/(2*m)) * sum(theta(2:end).^2);
% Gradient (grad is pre-initialized as zeros(size(theta)) in the template)
grad(1) = (X(:, 1).' * (h - y)) / m;
grad(2:end) = (X(:, 2:end).' * (h - y)) / m ...
    + (lambda/m) * theta(2:end);
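An alternative way to leave the intercept unregularized is to build a copy of `theta` with its first entry set to zero, so the cost and the gradient each become a single expression. This is a sketch under that assumption, not the exercise's required form; the data, `lambda`, and `theta_reg` are hypothetical. It compares the result with the two-part gradient used above.

```octave
% Toy data (made up), a hypothetical lambda and theta
X = [1 0.5 1.2; 1 -1.0 0.3; 1 2.0 -0.7; 1 0.1 0.9];
y = [1; 0; 1; 0];
m = length(y);
lambda = 1;
theta = [0.2; -0.4; 0.3];

sigmoid_fn = @(z) 1 ./ (1 + exp(-z));
h = sigmoid_fn(X * theta);

% Copy of theta with the intercept zeroed, so it drops out of the penalty
theta_reg = [0; theta(2:end)];
J = (-y' * log(h) - (1 - y)' * log(1 - h)) / m + (lambda / (2 * m)) * sum(theta_reg .^ 2);
grad = X' * (h - y) / m + (lambda / m) * theta_reg;

% Same gradient computed the two-part way, as in costFunctionReg.m
grad2 = zeros(size(theta));
grad2(1) = X(:, 1)' * (h - y) / m;
grad2(2:end) = X(:, 2:end)' * (h - y) / m + (lambda / m) * theta(2:end);
fprintf('max difference between the two gradients = %.2e\n', max(abs(grad - grad2)));
```

Both forms should agree to rounding error; the masked version is simply easier to keep consistent between `J` and `grad`.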