• 热门专题

bag-of-words model的Java实现

作者:  发布日期:2014-08-01 21:13:47
  • bag-of-words model的java实现
    为了验证paragraphVector的优势,需要拿bag-of-words model来对比。 实验数据:京东的评论,经人工挑选,分为“正面评论”和“负面评论”,中性的去掉。 分别拿这两个模型,来对每段“评论”做特征抽取,然后拿SVM来分类。 实验结果:400条训练,254条测试。bag-of-words模型的准确率是0.66,paragraphVector模型的准确率是0.84。

    下面给出bag-of-words model的实现。其实很简单,原理之前在《数学之美》看过。具体可以参考http://www.cnblogs.com/platero/archive/2012/12/03/2800251.html。

    训练数据: 1 文件good:正面评论 2 文件bad:负面评论 3 文件dict:其实就是good+bad,把正面评论和负面评论放在一起,主要遍历这个文件,找出所有词汇,生成词典。


    import java.io.BufferedReader;
    import java.io.BufferedWriter;
    import java.io.File;
    import java.io.FileInputStream;
    import java.io.FileNotFoundException;
    import java.io.FileOutputStream;
    import java.io.IOException;
    import java.io.InputStreamReader;
    import java.io.OutputStreamWriter;
    import java.io.UnsupportedEncodingException;
    import java.util.StringTokenizer;
    
    public class BowModel 
    {
    
    	
    	Dict dict;
    	DocFeatureFactory dff;
    	
    	public BowModel(String path) throws Throwable
    	{
    		dict = new Dict();
    		dict.loadFromLocalFile(path);		
    		dff = new DocFeatureFactory(dict.getWord2Index());
    	}
    	
    	
    
    	
    	
    	double[][] featureTable;
    	private void generateFeature(String docsFile,int docNum) throws IOException
    	{
    		featureTable = new double[docNum][];
    		int docIndex=0;
    		File file = new File(docsFile);
    		BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(file),"utf-8"));
    		while(true)
    		{
    			String line=br.readLine();
    			if(line == null)
    				break;
    			featureTable[docIndex++] = dff.getFeature(line);
    		}
    		br.close();		
    	}
    	
    	private void nomalizeFeature()
    	{
    		double sum=0;
    		double var =0;
    		for(int col=0;col<featureTable[0].length;col++)//一列代表一个维度
    		{
    			sum =0;
    			for(int row=0;row<featureTable.length;row++)
    			{
    				sum+= featureTable[row][col];
    			}
    			sum/=featureTable.length;//均值
    			var =0;
    			for(int row=0;row<featureTable.length;row++)
    			{
    				var+= (featureTable[row][col]-sum)*(featureTable[row][col]-sum);
    			}
    			var = Math.sqrt(var/featureTable.length);//标准差
    			if(var == 0) continue;
    			for(int row=0;row<featureTable.length;row++)
    			{
    				featureTable[row][col] = (featureTable[row][col] -sum)/var;
    			}
    		}
    	}
    	
    	private void saveFeature(String path,String label) throws IOException
    	{
    		File file=new File(path);
    		BufferedWriter br= new BufferedWriter(new OutputStreamWriter(new FileOutputStream(file)));
    		for(int i=0;i<featureTable.length;i++)
    		{
    			br.append(label+" ");
    			for(int j=0;j<featureTable[0].length;j++)
    			{
    				br.append(Integer.toString(j+1)+":"+featureTable[i][j]+" ");
    			}
    			br.append("
    ");
    		}
    		br.close();
    	}
    	
    	public void train() throws IOException
    	{
    		generateFeature("/media/linger/G/sources/comment/test/good",340);
    		nomalizeFeature();
    		saveFeature("svm_good","1");
    		
    		generateFeature("/media/linger/G/sources/comment/test/bad",314);
    		nomalizeFeature();
    		saveFeature("svm_bad","-1");
    	}
    	
    	
    	public static void main(String[] args) throws Throwable 
    	{
    		// TODO Auto-generated method stub
    		BowModel bm = new BowModel("/media/linger/G/sources/comment/test/dict");
    		bm.train();
    	}
    
    }
    



    import java.io.BufferedReader;
    import java.io.File;
    import java.io.FileInputStream;
    import java.io.FileNotFoundException;
    import java.io.IOException;
    import java.io.InputStreamReader;
    import java.io.UnsupportedEncodingException;
    import java.util.HashMap;
    import java.util.HashSet;
    import java.util.Hashtable;
    import java.util.StringTokenizer;
    
    /**
     * Vocabulary builder: scans a whitespace-tokenized UTF-8 text file and
     * assigns each distinct term a dense index (in order of first occurrence)
     * plus an occurrence count.
     */
    public class Dict 
    {
    	// term -> dense feature index, assigned in order of first sighting
    	HashMap<String,Integer> word2Index = null;
    	// term -> occurrence count (Hashtable type kept for field compatibility)
    	Hashtable<String,Integer> word2Count = null;
    	
    	/**
    	 * Builds word2Index and word2Count from the file at {@code path}.
    	 * Replaces any previously loaded vocabulary.
    	 *
    	 * @param path UTF-8 text file; tokens are separated by single spaces
    	 * @throws IOException on read failure
    	 */
    	void loadFromLocalFile(String path) throws IOException
    	{
    		word2Index = new HashMap<String,Integer>();
    		word2Count = new Hashtable<String,Integer>();
    		int index = 0;
    		// try-with-resources: the original leaked the reader on exception.
    		try (BufferedReader br = new BufferedReader(
    				new InputStreamReader(new FileInputStream(new File(path)), "utf-8")))
    		{
    			String line;
    			while ((line = br.readLine()) != null)
    			{
    				StringTokenizer tokenizer = new StringTokenizer(line, " ");
    				while (tokenizer.hasMoreTokens())
    				{
    					String term = tokenizer.nextToken();
    					if (word2Count.containsKey(term))
    					{
    						word2Count.put(term, word2Count.get(term) + 1);
    					}
    					else
    					{
    						// First sighting: count it and assign the next index.
    						word2Count.put(term, 1);
    						word2Index.put(term, index++);
    					}
    				}
    			}
    		}
    	}
    	
    	/**
    	 * Returns the term-to-index map built by loadFromLocalFile.
    	 *
    	 * @return term -> feature index map
    	 * @throws IllegalStateException if no file has been loaded yet
    	 */
    	public HashMap<String,Integer> getWord2Index() throws Throwable
    	{
    		if (word2Index == null)
    			// Was "throw new Exception(...)": a raw Exception for a
    			// programming error; IllegalStateException is the standard type.
    			throw new IllegalStateException("has not loaded file!");
    		return word2Index;
    	}
    	
    	public static void main(String[] args) 
    	{
    	}
    
    }
    



    import java.util.HashMap;
    import java.util.StringTokenizer;
    
    /**
     * Converts a single document (space-separated tokens) into a dense
     * term-count vector over a fixed vocabulary.
     */
    public class DocFeatureFactory 
    {
    	// term -> feature index, shared with the Dict that produced it
    	HashMap<String,Integer> word2Index;
    	double[] feature; // scratch vector, reallocated per call
    	int dim;          // vocabulary size == vector length
    	
    	/**
    	 * @param w2i term-to-index map; its size fixes the vector dimension
    	 */
    	public DocFeatureFactory(HashMap<String,Integer> w2i)
    	{
    		word2Index = w2i;
    		dim = w2i.size();
    	}
    	
    	/**
    	 * Returns the raw term-count vector for {@code doc}.
    	 * Out-of-vocabulary terms are ignored — the original unboxed
    	 * {@code word2Index.get(term)} directly and threw a
    	 * NullPointerException on any unseen term.
    	 *
    	 * @param doc document text, tokens separated by single spaces
    	 * @return count vector of length {@code dim} (freshly allocated)
    	 */
    	double[] getFeature(String doc)
    	{
    		feature = new double[dim];
    		StringTokenizer tokenizer = new StringTokenizer(doc, " ");
    		while (tokenizer.hasMoreTokens())
    		{
    			Integer idx = word2Index.get(tokenizer.nextToken());
    			if (idx != null) // skip unknown terms instead of NPE
    			{
    				feature[idx]++;
    			}
    		}	
    		return feature;
    	}
    	
    	public static void main(String[] args) 
    	{
    	}
    
    }
    

延伸阅读:

About IT165 - 广告服务 - 隐私声明 - 版权申明 - 免责条款 - 网站地图 - 网友投稿 - 联系方式
本站内容来自于互联网,仅供用于网络技术学习,学习中请遵循相关法律法规