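In machine learning we often need to add derived features (extra columns) to a dataset. Suppose there are two attributes, x and y. When the target is roughly linear in them, plain linear regression gives us a model easily. Often, though, a linear model (in mathematical terms, a first-degree polynomial in several variables) cannot fit the data.
In that case we build additional features from the ones we already have and check whether the fit improves. Starting from two columns, x and y, adding the degree-2 terms gives x, y, x^2, x*y, y^2: five features in total, and experience tells us that a model over these five features can fit a curve.
For degree 2 the combinations are easy to write out by hand. But with degree 5 and 4 original columns, how many columns should be added, and how is each one computed? That gets tedious. For example, (x+y+z)^3 expands to x^3 + y^3 + z^3 + 3xy^2 + 3xz^2 + 3x^2y + 3yz^2 + 3x^2z + 3y^2z + 6xyz. Drop the coefficients and the terms are exactly the columns to append. This post builds a program that, given m columns and a degree n, lists every combination.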
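For example, with m = 2 and n = 2 the result is [{0,2}, {1,1}, {2,0}], meaning three columns are appended: the first is x^0 * y^2, the second is x^1 * y^1, and the third is x^2 * y^0.
Looking at this, what we need is every non-negative integer solution of the equation X1 + X2 + X3 + … + Xm = n, where 0 <= Xi <= n for each i.
So the approach is: define an int[m] array whose m elements each range from 0 to n, and keep every assignment whose elements sum to n.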
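As a side note, the number of solutions can be counted before enumerating anything. This is the standard "stars and bars" result from combinatorics, added here for reference; the program itself does not compute it.

```latex
\#\bigl\{(X_1,\dots,X_m)\in\mathbb{Z}_{\ge 0}^{m} : X_1+X_2+\dots+X_m=n\bigr\}
  = \binom{n+m-1}{m-1}
```

For m = 2, n = 2 this gives C(3,1) = 3, matching the three combinations above. For m = 3, n = 3 it gives C(5,2) = 10, the ten terms of (x+y+z)^3. For m = 3, n = 5 it gives C(7,2) = 21, and for the 4-column, degree-5 case it gives C(8,3) = 56.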
Here is the program:
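/**
 * @author wuweifeng wrote on 2018/6/4.
 */
public class LineAdder {
    // number of columns (m)
    private static int lines = 3;
    // target degree (n)
    private static int power = 5;
    // exponent currently assigned to each column
    private static int[] resultArray;

    public static void main(String[] args) {
        resultArray = new int[lines];
        deal(0);
    }

    /**
     * Tries every exponent from 0 to power at position m, then recurses into the next position.
     */
    public static void deal(int m) {
        for (int i = 0; i <= power; i++) {
            resultArray[m] = i;
            if (m == lines - 1) {
                // last position filled: a solution is found if the exponents sum to power
                if (check()) {
                    print();
                    // any larger i would overshoot the sum, so stop here
                    return;
                }
            } else {
                deal(m + 1);
            }
        }
    }

    /**
     * Checks whether the current exponents form a solution.
     *
     * @return true if they sum to power
     */
    private static boolean check() {
        int total = 0;
        for (int one : resultArray) {
            total += one;
        }
        return power == total;
    }

    private static void print() {
        for (int one : resultArray) {
            System.out.print(one);
        }
        System.out.print("\n");
    }
}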
The output is:
005
014
023
032
041
050
104
113
122
131
140
203
212
221
230
302
311
320
401
410
500
This is the complete list of combinations for 3 columns at degree 5.
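The program above tries up to (n+1)^m assignments and keeps only those that sum to n. As an alternative sketch (not the author's version), the recursion can carry the remaining degree down, so that every branch it explores ends in a valid solution. The class name `ExponentEnumerator` and its methods are hypothetical names chosen for this illustration.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper for illustration; not part of the original program.
class ExponentEnumerator {

    // Returns every int[columns] of non-negative exponents whose sum is exactly degree.
    static List<int[]> enumerate(int columns, int degree) {
        List<int[]> results = new ArrayList<>();
        if (columns <= 0) {
            return results; // nothing to combine
        }
        fill(new int[columns], 0, degree, results);
        return results;
    }

    private static void fill(int[] current, int index, int remaining, List<int[]> results) {
        if (index == current.length - 1) {
            // The last exponent is forced: it must use up whatever degree is left.
            current[index] = remaining;
            results.add(current.clone());
            return;
        }
        // Only try exponents that still leave a non-negative remainder.
        for (int i = 0; i <= remaining; i++) {
            current[index] = i;
            fill(current, index + 1, remaining - i, results);
        }
    }

    public static void main(String[] args) {
        List<int[]> all = enumerate(3, 5);
        System.out.println(all.size() + " combinations");
        for (int[] one : all) {
            StringBuilder sb = new StringBuilder();
            for (int e : one) {
                sb.append(e);
            }
            System.out.println(sb);
        }
    }
}
```

For 3 columns at degree 5 this sketch produces the same 21 combinations in the same order as the listing above, from 005 to 500.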
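Next we improve the LineAdder program so that it can process a text file line by line and append the new columns directly to each line.
Here is the code: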
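package ploy;

import java.util.ArrayList;
import java.util.List;

/**
 * @author wuweifeng wrote on 2018/6/4.
 */
public class LineAdder {
    // number of columns (m)
    private int lines = 3;
    // target degree (n)
    private int power = 5;
    // every solution found so far
    private List<int[]> resultList = new ArrayList<>();
    // exponent currently assigned to each column
    private int[] resultArray;

    public List<int[]> lineAdd(int lines, int power) {
        resultArray = new int[lines];
        // start from an empty list so repeated calls do not accumulate old results
        resultList = new ArrayList<>();
        this.lines = lines;
        this.power = power;
        deal(0);
        return resultList;
    }

    private void deal(int m) {
        for (int i = 0; i <= power; i++) {
            resultArray[m] = i;
            if (m == lines - 1) {
                // last position filled: a solution is found if the exponents sum to power
                if (check()) {
                    print();
                    // any larger i would overshoot the sum, so stop here
                    return;
                }
            } else {
                deal(m + 1);
            }
        }
    }

    /**
     * Checks whether the current exponents form a solution.
     *
     * @return true if they sum to power
     */
    private boolean check() {
        int total = 0;
        for (int one : resultArray) {
            total += one;
        }
        return power == total;
    }

    /**
     * Prints the current solution and stores a copy of it in resultList.
     */
    private void print() {
        for (int one : resultArray) {
            System.out.print(one);
        }
        System.out.print("\n");
        int[] temp = new int[resultArray.length];
        System.arraycopy(resultArray, 0, temp, 0, resultArray.length);
        resultList.add(temp);
    }
}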

package ploy;

import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.List;

/**
 * @author wuweifeng wrote on 2018/6/5.
 */
public class TextDeal {

    public static void main(String[] args) throws IOException {
        new TextDeal().linePower("/Users/wuwf/Downloads/ml_data/1逻辑回归入门/train_test_deal.csv",
                "/Users/wuwf/Downloads/ml_data/1逻辑回归入门/train_test_deal-3.csv", 3, 1, 2, 3, 6);
    }

    /**
     * @param filePath   path of the input file
     * @param outputPath path of the output file
     * @param power      the degree to expand to
     * @param lineNums   indexes (starting at 0) of the columns to expand; omit to use all columns
     */
    public void linePower(String filePath, String outputPath, Integer power, Integer... lineNums) throws IOException {
        // try-with-resources closes both streams even if processing fails
        try (BufferedReader reader = buildReader(filePath);
             BufferedWriter writer = buildWriter(outputPath)) {
            addCSVHeader(reader, writer, power, lineNums);
        }
    }

    private Integer[] getLineNums(String[] lines, Integer... lineNums) {
        // When no indexes are passed, varargs delivers an empty array (not null),
        // so treat both null and empty as "use every column"
        if (lineNums == null || lineNums.length == 0) {
            lineNums = new Integer[lines.length];
            for (int i = 0; i < lines.length; i++) {
                lineNums[i] = i;
            }
        }
        return lineNums;
    }

    private List<int[]> getAddList(int power, Integer... lineNums) {
        LineAdder lineAdder = new LineAdder();
        // work out which columns have to be added
        return lineAdder.lineAdd(lineNums.length, power);
    }

    /**
     * Appends the new column names to the header (the first line),
     * then appends the computed values to every data row.
     */
    private void addCSVHeader(BufferedReader reader, BufferedWriter writer, Integer power, Integer... lineNums)
            throws IOException {
        // read the first line
        String header = reader.readLine();
        if (header == null) {
            // empty input file: nothing to do
            return;
        }
        // all column names
        String[] lines = header.split(",");
        lineNums = getLineNums(lines, lineNums);
        // work out which columns have to be added
        List<int[]> list = getAddList(power, lineNums);
        String[] addLines = new String[list.size()];
        String[] needLines = new String[lineNums.length];
        for (int i = 0; i < lineNums.length; i++) {
            needLines[i] = lines[lineNums[i]];
        }
        // build the name of each new column, e.g. a0b2 for a^0 * b^2
        for (int i = 0; i < list.size(); i++) {
            int[] array = list.get(i);
            String s = "";
            for (int j = 0; j < array.length; j++) {
                s += needLines[j] + array[j];
            }
            addLines[i] = s;
        }
        for (String addLine : addLines) {
            header += "," + addLine;
        }
        // write the extended header to the output file
        writer.write(header);
        writer.newLine();
        writer.flush();
        String oneLine;
        while ((oneLine = reader.readLine()) != null) {
            addLines = new String[list.size()];
            lines = oneLine.split(",");
            needLines = new String[lineNums.length];
            for (int i = 0; i < lineNums.length; i++) {
                needLines[i] = lines[lineNums[i]];
            }
            // compute the value of each new column
            for (int i = 0; i < list.size(); i++) {
                int[] array = list.get(i);
                double s = 1;
                try {
                    for (int j = 0; j < array.length; j++) {
                        // e.g. for columns a, b and exponents 02, the value is a^0 * b^2
                        s *= Math.pow(Double.parseDouble(needLines[j]), array[j]);
                    }
                    addLines[i] = s + "";
                } catch (NumberFormatException e) {
                    // non-numeric cell: mark the value as unknown
                    addLines[i] = "?";
                }
            }
            for (String addLine : addLines) {
                oneLine += "," + addLine;
            }
            // write the extended row to the output file
            writer.write(oneLine);
            writer.newLine();
        }
        // push any buffered rows to the output file
        writer.flush();
    }

    private BufferedReader buildReader(String filePath) throws FileNotFoundException {
        // read as UTF-8 so the input matches the encoding used for the output
        return new BufferedReader(new InputStreamReader(new FileInputStream(filePath), StandardCharsets.UTF_8));
    }

    private BufferedWriter buildWriter(String outputPath) throws FileNotFoundException {
        // write the output file as UTF-8
        return new BufferedWriter(new OutputStreamWriter(new FileOutputStream(outputPath), StandardCharsets.UTF_8));
    }
}
Suppose the CSV file looks like this:
a,b
1,2
2,3
4,5
After running, the result is:
a,b,a0b2,a1b1,a2b0
1,2,4.0,2.0,1.0
2,3,9.0,6.0,4.0
4,5,25.0,20.0,16.0
As you can see, the degree-2 expansion has been applied.
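To make the per-row arithmetic easier to check by hand, here is a minimal sketch that isolates it from the file handling. The class `RowExpander` and its `expand` method are hypothetical names introduced for this illustration; the calculation is the same product of powers that TextDeal performs for each data row.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical helper for illustration; the name RowExpander is not from the original code.
class RowExpander {

    // For each exponent array, multiplies values[j] raised to exponents[j] over all j.
    static double[] expand(double[] values, List<int[]> exponentList) {
        double[] added = new double[exponentList.size()];
        for (int i = 0; i < exponentList.size(); i++) {
            int[] exponents = exponentList.get(i);
            double product = 1.0;
            for (int j = 0; j < exponents.length; j++) {
                product *= Math.pow(values[j], exponents[j]);
            }
            added[i] = product;
        }
        return added;
    }

    public static void main(String[] args) {
        // The degree-2 combinations for two columns: a0b2, a1b1, a2b0.
        List<int[]> degreeTwo = Arrays.asList(new int[]{0, 2}, new int[]{1, 1}, new int[]{2, 0});
        // The row a=2, b=3 from the sample file.
        System.out.println(Arrays.toString(expand(new double[]{2, 3}, degreeTwo)));
    }
}
```

For the row 2,3 this prints [9.0, 6.0, 4.0], the same three values appended to that row in the output above.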
The TextDeal class can generate and compute this expansion for any degree.